diff --git a/include/array_base.h b/include/array_base.h index c53e6b12..3ad75207 100644 --- a/include/array_base.h +++ b/include/array_base.h @@ -37,9 +37,9 @@ public: } // len determines length of output buffer (default: length of input) - uint32_t* uncompress(uint32_t len=0); + uint32_t* uncompress(uint32_t len=0) const; uint32_t getSizeInBytes(); - uint32_t getLength(); + uint32_t getLength() const; }; \ No newline at end of file diff --git a/include/art.h b/include/art.h index 1047fc9a..bd887209 100644 --- a/include/art.h +++ b/include/art.h @@ -122,6 +122,8 @@ typedef struct { } art_document; enum token_ordering { + NOT_SET, + FREQUENCY, MAX_SCORE }; diff --git a/include/collection_manager.h b/include/collection_manager.h index 978cb050..0702519e 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -127,7 +127,7 @@ public: Option create_collection(const std::string& name, const size_t num_memory_shards, const std::vector & fields, - const std::string & default_sorting_field, + const std::string & default_sorting_field="", const uint64_t created_at = static_cast(std::time(nullptr)), const bool index_all_fields = false); diff --git a/include/field.h b/include/field.h index 121f5e51..811f58f1 100644 --- a/include/field.h +++ b/include/field.h @@ -164,8 +164,9 @@ struct field { } static Option fields_to_json_fields(const std::vector & fields, - const std::string & default_sorting_field, nlohmann::json& fields_json, - bool& found_default_sorting_field) { + const std::string & default_sorting_field, nlohmann::json& fields_json) { + bool found_default_sorting_field = false; + for(const field & field: fields) { nlohmann::json field_val; field_val[fields::name] = field.name; @@ -197,6 +198,11 @@ struct field { } } + if(!default_sorting_field.empty() && !found_default_sorting_field) { + return Option(400, "Default sorting field is defined as `" + default_sorting_field + + "` but is not found in the schema."); + } + return Option(true); } }; 
@@ -276,6 +282,7 @@ namespace sort_field_const { static const std::string asc = "ASC"; static const std::string desc = "DESC"; static const std::string text_match = "_text_match"; + static const std::string seq_id = "_seq_id"; } struct sort_by { diff --git a/include/index.h b/include/index.h index c2076b36..f0b71333 100644 --- a/include/index.h +++ b/include/index.h @@ -56,6 +56,7 @@ struct search_args { size_t typo_tokens_threshold; std::vector group_by_fields; size_t group_limit; + std::string default_sorting_field; size_t all_result_ids_len; spp::sparse_hash_set groups_processed; std::vector> searched_queries; @@ -76,14 +77,15 @@ struct search_args { std::vector sort_fields_std, facet_query_t facet_query, int num_typos, size_t max_facet_values, size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix, size_t drop_tokens_threshold, size_t typo_tokens_threshold, - const std::vector& group_by_fields, size_t group_limit): + const std::vector& group_by_fields, size_t group_limit, + const std::string& default_sorting_field): q_include_tokens(q_include_tokens), q_exclude_tokens(q_exclude_tokens), q_synonyms(q_synonyms), search_fields(search_fields), filters(filters), facets(facets), included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std), facet_query(facet_query), num_typos(num_typos), max_facet_values(max_facet_values), per_page(per_page), page(page), token_order(token_order), prefix(prefix), drop_tokens_threshold(drop_tokens_threshold), typo_tokens_threshold(typo_tokens_threshold), - group_by_fields(group_by_fields), group_limit(group_limit), + group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field), all_result_ids_len(0) { const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory @@ -169,6 +171,9 @@ private: // sort_field => (seq_id => value) spp::sparse_hash_map*> sort_index; + // this is used for 
wildcard queries + sorted_array seq_ids; + StringUtils string_utils; // Internal utility functions @@ -349,7 +354,8 @@ public: std::vector> & override_result_kvs, const size_t typo_tokens_threshold, const size_t group_limit, - const std::vector& group_by_fields) const; + const std::vector& group_by_fields, + const std::string& default_sorting_field) const; Option remove(const uint32_t seq_id, const nlohmann::json & document); diff --git a/src/array_base.cpp b/src/array_base.cpp index 2eae9dfd..860a1a06 100644 --- a/src/array_base.cpp +++ b/src/array_base.cpp @@ -1,6 +1,6 @@ #include "array_base.h" -uint32_t* array_base::uncompress(uint32_t len) { +uint32_t* array_base::uncompress(uint32_t len) const { uint32_t actual_len = std::max(len, length); uint32_t *out = new uint32_t[actual_len]; for_uncompress(in, out, length); @@ -11,6 +11,6 @@ uint32_t array_base::getSizeInBytes() { return size_bytes; } -uint32_t array_base::getLength() { +uint32_t array_base::getLength() const { return length; } diff --git a/src/collection.cpp b/src/collection.cpp index 91b29bed..9c3c8a1a 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -491,7 +491,7 @@ Option Collection::search(const std::string & query, const std:: const std::string & simple_filter_query, const std::vector& facet_fields, const std::vector & sort_fields, const int num_typos, const size_t per_page, const size_t page, - const token_ordering token_order, const bool prefix, + token_ordering token_order, const bool prefix, const size_t drop_tokens_threshold, const spp::sparse_hash_set & include_fields, const spp::sparse_hash_set & exclude_fields, @@ -755,7 +755,11 @@ Option Collection::search(const std::string & query, const std:: */ if(sort_fields_std.empty()) { sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc); - sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc); + if(!default_sorting_field.empty()) { + sort_fields_std.emplace_back(default_sorting_field, 
sort_field_const::desc); + } else { + sort_fields_std.emplace_back(sort_field_const::seq_id, sort_field_const::desc); + } } bool found_match_score = false; @@ -801,6 +805,14 @@ Option Collection::search(const std::string & query, const std:: max_hits = std::min(std::max((page * per_page), max_hits), get_num_documents()); } + if(token_order == NOT_SET) { + if(default_sorting_field.empty()) { + token_order = FREQUENCY; + } else { + token_order = MAX_SCORE; + } + } + std::vector> searched_queries; // search queries used for generating the results std::vector> raw_result_kvs; std::vector> override_result_kvs; @@ -833,7 +845,7 @@ Option Collection::search(const std::string & query, const std:: sort_fields_std, facet_query, num_typos, max_facet_values, max_hits, per_page, page, token_order, prefix, drop_tokens_threshold, typo_tokens_threshold, - group_by_fields, group_limit); + group_by_fields, group_limit, default_sorting_field); search_args_vec.push_back(search_params); @@ -2264,11 +2276,9 @@ Option Collection::check_and_update_schema(nlohmann::json& document) { try { collection_meta = nlohmann::json::parse(coll_meta_json); - bool found_default_sorting_field = false; nlohmann::json fields_json = nlohmann::json::array();; - Option fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json, - found_default_sorting_field); + Option fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json); if(!fields_json_op.ok()) { return Option(fields_json_op.code(), fields_json_op.error()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 4a25e9b6..411a59e7 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -304,7 +304,7 @@ bool CollectionManager::auth_key_matches(const std::string& auth_key_sent, Option CollectionManager::create_collection(const std::string& name, const size_t num_memory_shards, const std::vector & fields, - const std::string & default_sorting_field, + 
const std::string& default_sorting_field, const uint64_t created_at, const bool index_all_fields) { std::unique_lock lock(mutex); @@ -313,21 +313,14 @@ Option CollectionManager::create_collection(const std::string& name return Option(409, std::string("A collection with name `") + name + "` already exists."); } - bool found_default_sorting_field = false; nlohmann::json fields_json = nlohmann::json::array();; - Option fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json, - found_default_sorting_field); + Option fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json); if(!fields_json_op.ok()) { return Option(fields_json_op.code(), fields_json_op.error()); } - if(!found_default_sorting_field) { - return Option(400, "Default sorting field is defined as `" + default_sorting_field + - "` but is not found in the schema."); - } - nlohmann::json collection_meta; collection_meta[Collection::COLLECTION_NAME_KEY] = name; collection_meta[Collection::COLLECTION_ID_KEY] = next_collection_id.load(); @@ -765,12 +758,16 @@ Option CollectionManager::do_search(std::map& re const size_t drop_tokens_threshold = (size_t) std::stoi(req_params[DROP_TOKENS_THRESHOLD]); const size_t typo_tokens_threshold = (size_t) std::stoi(req_params[TYPO_TOKENS_THRESHOLD]); - if(req_params.count(RANK_TOKENS_BY) == 0) { - req_params[RANK_TOKENS_BY] = "DEFAULT_SORTING_FIELD"; - } + token_ordering token_order = NOT_SET; - StringUtils::toupper(req_params[RANK_TOKENS_BY]); - token_ordering token_order = (req_params[RANK_TOKENS_BY] == "DEFAULT_SORTING_FIELD") ? 
MAX_SCORE : FREQUENCY; + if(req_params.count(RANK_TOKENS_BY) != 0) { + StringUtils::toupper(req_params[RANK_TOKENS_BY]); + if (req_params[RANK_TOKENS_BY] == "DEFAULT_SORTING_FIELD") { + token_order = MAX_SCORE; + } else if(req_params[RANK_TOKENS_BY] == "FREQUENCY") { + token_order = FREQUENCY; + } + } Option result_op = collection->search(req_params[QUERY], search_fields, filter_str, facet_fields, sort_fields, std::stoi(req_params[NUM_TYPOS]), diff --git a/src/core_api.cpp b/src/core_api.cpp index d37da108..34650c25 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -108,8 +108,7 @@ bool post_create_collection(http_req & req, http_res & res) { const char* DEFAULT_SORTING_FIELD = "default_sorting_field"; if(req_json.count(DEFAULT_SORTING_FIELD) == 0) { - res.set_400("Parameter `default_sorting_field` is required."); - return false; + req_json[DEFAULT_SORTING_FIELD] = ""; } if(!req_json[DEFAULT_SORTING_FIELD].is_string()) { diff --git a/src/index.cpp b/src/index.cpp index ecacf9df..b2f2c3b1 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -67,16 +67,14 @@ Index::~Index() { int64_t Index::get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field) { int64_t points = 0; - if(!default_sorting_field.empty()) { - if(document[default_sorting_field].is_number_float()) { - // serialize float to an integer and reverse the inverted range - float n = document[default_sorting_field]; - memcpy(&points, &n, sizeof(int32_t)); - points ^= ((points >> (std::numeric_limits::digits - 1)) | INT32_MIN); - points = -1 * (INT32_MAX - points); - } else { - points = document[default_sorting_field]; - } + if(document[default_sorting_field].is_number_float()) { + // serialize float to an integer and reverse the inverted range + float n = document[default_sorting_field]; + memcpy(&points, &n, sizeof(int32_t)); + points ^= ((points >> (std::numeric_limits::digits - 1)) | INT32_MIN); + points = -1 * (INT32_MAX - points); + } else { + points = 
document[default_sorting_field]; } return points; @@ -99,12 +97,20 @@ Option Index::index_in_memory(const nlohmann::json &document, uint32_t int64_t points = 0; - if(is_update && document.count(default_sorting_field) == 0) { - points = sort_index[default_sorting_field]->at(seq_id); + if(document.count(default_sorting_field) == 0) { + if(sort_index.count(default_sorting_field) != 0 && sort_index[default_sorting_field]->count(seq_id)) { + points = sort_index[default_sorting_field]->at(seq_id); + } else { + points = INT64_MIN; + } } else { points = get_points_from_doc(document, default_sorting_field); } + if(!is_update) { + seq_ids.append(seq_id); + } + std::unordered_map facet_to_id; size_t i_facet = 0; for(const auto & facet: facet_schema) { @@ -266,23 +272,13 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 bool index_all_fields, const DIRTY_VALUES& dirty_values) { - bool has_default_sort_field = (document.count(default_sorting_field) != 0); + bool missing_default_sort_field = (!default_sorting_field.empty() && document.count(default_sorting_field) == 0); - if(!has_default_sort_field && !is_update) { + if(!is_update && missing_default_sort_field) { return Option<>(400, "Field `" + default_sorting_field + "` has been declared as a default sorting field, " "but is not found in the document."); } - if(has_default_sort_field && - !document[default_sorting_field].is_number_integer() && !document[default_sorting_field].is_number_float()) { - return Option<>(400, "Default sorting field `" + default_sorting_field + "` must be a single valued numerical field."); - } - - if(has_default_sort_field && search_schema.at(default_sorting_field).is_single_float() && - document[default_sorting_field].get() > std::numeric_limits::max()) { - return Option<>(400, "Default sorting field `" + default_sorting_field + "` exceeds maximum value of a float."); - } - for(const auto& field_pair: search_schema) { const std::string& field_name = field_pair.first; @@ 
-317,18 +313,18 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 } } } else if(field_pair.second.type == field_types::INT64 && !document[field_name].is_number_integer()) { - Option coerce_op = coerce_int64_t(dirty_values, document, field_name, false); + Option coerce_op = coerce_int64_t(dirty_values, document, field_name, -1); if(!coerce_op.ok()) { return coerce_op; } } else if(field_pair.second.type == field_types::FLOAT && !document[field_name].is_number()) { // using `is_number` allows integer to be passed to a float field - Option coerce_op = coerce_float(dirty_values, document, field_name, false); + Option coerce_op = coerce_float(dirty_values, document, field_name, -1); if(!coerce_op.ok()) { return coerce_op; } } else if(field_pair.second.type == field_types::BOOL && !document[field_name].is_boolean()) { - Option coerce_op = coerce_bool(dirty_values, document, field_name, false); + Option coerce_op = coerce_bool(dirty_values, document, field_name, -1); if(!coerce_op.ok()) { return coerce_op; } @@ -356,18 +352,18 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 return coerce_op; } } else if (field_pair.second.type == field_types::INT64_ARRAY && !item.is_number_integer()) { - Option coerce_op = coerce_int64_t(dirty_values, document, field_name, true); + Option coerce_op = coerce_int64_t(dirty_values, document, field_name, arr_index); if (!coerce_op.ok()) { return coerce_op; } } else if (field_pair.second.type == field_types::FLOAT_ARRAY && !item.is_number()) { // we check for `is_number` to allow whole numbers to be passed into float fields - Option coerce_op = coerce_float(dirty_values, document, field_name, true); + Option coerce_op = coerce_float(dirty_values, document, field_name, arr_index); if (!coerce_op.ok()) { return coerce_op; } } else if (field_pair.second.type == field_types::BOOL_ARRAY && !item.is_boolean()) { - Option coerce_op = coerce_bool(dirty_values, document, field_name, true); + Option 
coerce_op = coerce_bool(dirty_values, document, field_name, arr_index); if (!coerce_op.ok()) { return coerce_op; } @@ -1335,7 +1331,8 @@ void Index::run_search(search_args* search_params) { search_params->searched_queries, search_params->raw_result_kvs, search_params->override_result_kvs, search_params->typo_tokens_threshold, - search_params->group_limit, search_params->group_by_fields); + search_params->group_limit, search_params->group_by_fields, + search_params->default_sorting_field); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -1427,7 +1424,8 @@ void Index::search(const std::vector& q_include_tokens, std::vector> & override_result_kvs, const size_t typo_tokens_threshold, const size_t group_limit, - const std::vector& group_by_fields) const { + const std::vector& group_by_fields, + const std::string& default_sorting_field) const { std::shared_lock lock(mutex); @@ -1491,27 +1489,22 @@ void Index::search(const std::vector& q_include_tokens, // if a filter is not specified, use the sorting index to generate the list of all document ids if(filters.empty()) { - std::string all_records_field; + if(default_sorting_field.empty()) { + filter_ids_length = seq_ids.getLength(); + filter_ids = seq_ids.uncompress(); + } else { + const spp::sparse_hash_map *kvs = sort_index.at(default_sorting_field); + filter_ids_length = kvs->size(); + filter_ids = new uint32_t[filter_ids_length]; - // get the first non-optional field - for(const auto& kv: sort_schema) { - if(!kv.second.optional && kv.first != sort_field_const::text_match) { - all_records_field = kv.first; - break; + size_t i = 0; + for(const auto& kv: *kvs) { + filter_ids[i++] = kv.first; } + + // ids populated from hash map will not be sorted, but sorting is required for intersection & other ops + std::sort(filter_ids, filter_ids+filter_ids_length); } - - const spp::sparse_hash_map *kvs = sort_index.at(all_records_field); - filter_ids_length = kvs->size(); - filter_ids = new 
uint32_t[filter_ids_length]; - - size_t i = 0; - for(const auto& kv: *kvs) { - filter_ids[i++] = kv.first; - } - - // ids populated from hash map will not be sorted, but sorting is required for intersection & other ops - std::sort(filter_ids, filter_ids+filter_ids_length); } if(!curated_ids.empty()) { @@ -1919,7 +1912,7 @@ void Index::score_results(const std::vector & sort_fields, const uint16 const size_t group_limit, const std::vector& group_by_fields, uint32_t token_bits) const { - std::vector leaf_to_indices; + std::vector leaf_to_indices; for (art_leaf *token_leaf: query_suggestion) { uint32_t *indices = new uint32_t[result_size]; token_leaf->values->ids.indexOf(result_ids, result_size, indices); @@ -1937,19 +1930,25 @@ void Index::score_results(const std::vector & sort_fields, const uint16 spp::sparse_hash_map geopoint_distances[3]; - for(size_t i = 0; i < sort_fields.size(); i++) { + spp::sparse_hash_map text_match_sentinel_value, seq_id_sentinel_value; + spp::sparse_hash_map *TEXT_MATCH_SENTINEL = &text_match_sentinel_value; + spp::sparse_hash_map *SEQ_ID_SENTINEL = &seq_id_sentinel_value; + + for (size_t i = 0; i < sort_fields.size(); i++) { sort_order[i] = 1; - if(sort_fields[i].order == sort_field_const::asc) { + if (sort_fields[i].order == sort_field_const::asc) { sort_order[i] = -1; } - if(sort_fields[i].name == sort_field_const::text_match) { - field_values[i] = nullptr; - } else if(sort_schema.at(sort_fields[i].name).is_geopoint()) { + if (sort_fields[i].name == sort_field_const::text_match) { + field_values[i] = TEXT_MATCH_SENTINEL; + } else if (sort_fields[i].name == sort_field_const::seq_id) { + field_values[i] = SEQ_ID_SENTINEL; + } else if (sort_schema.at(sort_fields[i].name).is_geopoint()) { // we have to populate distances that will be used for match scoring - spp::sparse_hash_map* geopoints = sort_index.at(sort_fields[i].name); + spp::sparse_hash_map *geopoints = sort_index.at(sort_fields[i].name); - for(size_t rindex=0; rindexfind(seq_id); 
int64_t dist = (it == geopoints->end()) ? INT32_MAX : h3Distance(sort_fields[i].geopoint, it->second); @@ -1964,23 +1963,23 @@ void Index::score_results(const std::vector & sort_fields, const uint16 //auto begin = std::chrono::high_resolution_clock::now(); - for(size_t i=0; i>> array_token_positions; populate_token_positions(query_suggestion, leaf_to_indices, i, array_token_positions); - for(const auto& kv: array_token_positions) { - const std::vector>& token_positions = kv.second; - if(token_positions.empty()) { + for (const auto& kv: array_token_positions) { + const std::vector> &token_positions = kv.second; + if (token_positions.empty()) { continue; } - const Match & match = Match(seq_id, token_positions, false); + const Match &match = Match(seq_id, token_positions, false); uint64_t this_match_score = match.get_match_score(total_cost); match_score += this_match_score; @@ -2000,40 +1999,49 @@ void Index::score_results(const std::vector & sort_fields, const uint16 size_t match_score_index = 0; // avoiding loop - if(sort_fields.size() > 0) { - if (field_values[0] != nullptr) { - auto it = field_values[0]->find(seq_id); - scores[0] = (it == field_values[0]->end()) ? default_score : it->second; - } else { + if (sort_fields.size() > 0) { + if (field_values[0] == TEXT_MATCH_SENTINEL) { scores[0] = int64_t(match_score); match_score_index = 0; + } else if (field_values[0] == SEQ_ID_SENTINEL) { + scores[0] = seq_id; + } else { + auto it = field_values[0]->find(seq_id); + scores[0] = (it == field_values[0]->end()) ? default_score : it->second; } if (sort_order[0] == -1) { scores[0] = -scores[0]; } } + if(sort_fields.size() > 1) { - if (field_values[1] != nullptr) { - auto it = field_values[1]->find(seq_id); - scores[1] = (it == field_values[1]->end()) ? 
default_score : it->second; - } else { + if (field_values[1] == TEXT_MATCH_SENTINEL) { scores[1] = int64_t(match_score); match_score_index = 1; + } else if (field_values[1] == SEQ_ID_SENTINEL) { + scores[1] = seq_id; + } else { + auto it = field_values[1]->find(seq_id); + scores[1] = (it == field_values[1]->end()) ? default_score : it->second; } + if (sort_order[1] == -1) { scores[1] = -scores[1]; } } if(sort_fields.size() > 2) { - if(field_values[2] != nullptr) { - auto it = field_values[2]->find(seq_id); - scores[2] = (it == field_values[2]->end()) ? default_score : it->second; - } else { + if(field_values[2] == TEXT_MATCH_SENTINEL) { scores[2] = int64_t(match_score); match_score_index = 2; + } else if (field_values[2] == SEQ_ID_SENTINEL) { + scores[2] = seq_id; + } else { + auto it = field_values[2]->find(seq_id); + scores[2] = (it == field_values[2]->end()) ? default_score : it->second; } + if(sort_order[2] == -1) { scores[2] = -scores[2]; } @@ -2314,6 +2322,8 @@ Option Index::remove(const uint32_t seq_id, const nlohmann::json & doc } } + seq_ids.remove_value(seq_id); + return Option(seq_id); } diff --git a/test/collection_faceting_test.cpp b/test/collection_faceting_test.cpp index ea63463a..0204ea6d 100644 --- a/test/collection_faceting_test.cpp +++ b/test/collection_faceting_test.cpp @@ -780,10 +780,6 @@ TEST_F(CollectionFacetingTest, FacetCountOnSimilarStrings) { token_ordering::FREQUENCY, true, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10).get(); - LOG(INFO) << results; - - return; - ASSERT_EQ(2, results["hits"].size()); ASSERT_EQ(2, results["facet_counts"][0]["counts"].size()); diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 8ce28d31..d083c3b7 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -304,6 +304,13 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) { ASSERT_EQ(1, coll1->get_collection_id()); ASSERT_EQ(3, coll1->get_sort_fields().size()); + // index 
a document with a bad field value with COERCE_OR_IGNORE setting + auto doc_json = R"({"title": "Unique record.", "max": 25, "scores": [22, "how", 44], + "average": "bad data", "is_valid": true})"; + + Option add_op = coll1->add(doc_json, CREATE, "", DIRTY_VALUES::COERCE_OR_IGNORE); + ASSERT_TRUE(add_op.ok()); + std::unordered_map schema = collection1->get_schema(); // create a new collection manager to ensure that it restores the records from the disk backed store @@ -324,7 +331,8 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) { auto restored_schema = restored_coll->get_schema(); ASSERT_EQ(1, restored_coll->get_collection_id()); - ASSERT_EQ(6, restored_coll->get_next_seq_id()); + ASSERT_EQ(7, restored_coll->get_next_seq_id()); + ASSERT_EQ(7, restored_coll->get_num_documents()); ASSERT_EQ(facet_fields_expected, restored_coll->get_facet_fields()); ASSERT_EQ(3, restored_coll->get_sort_fields().size()); ASSERT_EQ("is_valid", restored_coll->get_sort_fields()[0].name); @@ -347,6 +355,24 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) { ASSERT_FALSE(kv.second.optional); } + // try searching for record with bad data + auto results = restored_coll->search("unique", {"title"}, "", {}, {}, 0, 10, 1, FREQUENCY, false).get(); + + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("Unique record.", results["hits"][0]["document"]["title"].get().c_str()); + ASSERT_EQ(0, results["hits"][0]["document"].count("average")); + ASSERT_EQ(2, results["hits"][0]["document"]["scores"].size()); + ASSERT_EQ(22, results["hits"][0]["document"]["scores"][0]); + ASSERT_EQ(44, results["hits"][0]["document"]["scores"][1]); + + // try sorting on `average`, a field that not all records have + ASSERT_EQ(7, restored_coll->get_num_documents()); + + sort_fields = { sort_by("average", "DESC") }; + results = restored_coll->search("*", {"title"}, "", {}, {sort_fields}, 0, 10, 1, FREQUENCY, false).get(); + + ASSERT_EQ(7, results["hits"].size()); + 
collectionManager.drop_collection("coll1"); collectionManager2.drop_collection("coll1"); } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index abd904cc..99f4804c 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -129,19 +129,114 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) { std::vector sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") }; Option collection_op = collectionManager.create_collection("sample_collection", 4, fields, "name"); - EXPECT_FALSE(collection_op.ok()); - EXPECT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error()); + ASSERT_FALSE(collection_op.ok()); + ASSERT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error()); collectionManager.drop_collection("sample_collection"); // Default sorting field must exist as a field in schema sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") }; collection_op = collectionManager.create_collection("sample_collection", 4, fields, "NOT-DEFINED"); - EXPECT_FALSE(collection_op.ok()); - EXPECT_EQ("Default sorting field is defined as `NOT-DEFINED` but is not found in the schema.", collection_op.error()); + ASSERT_FALSE(collection_op.ok()); + ASSERT_EQ("Default sorting field is defined as `NOT-DEFINED` but is not found in the schema.", collection_op.error()); collectionManager.drop_collection("sample_collection"); } +TEST_F(CollectionSortingTest, NoDefaultSortingField) { + Collection *coll1; + + std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl"); + std::vector fields = {field("title", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 4, fields).get(); + } + + std::string json_line; + + while (std::getline(infile, json_line)) { + 
coll1->add(json_line); + } + + infile.close(); + + // without a default sorting field, matches should be sorted by (text_match, seq_id) + auto results = coll1->search("rocket", {"title"}, "", {}, {}, 1, 10, 1, FREQUENCY, false).get(); + + ASSERT_EQ(5, results["found"]); + ASSERT_EQ(5, results["hits"].size()); + ASSERT_EQ(24, results["out_of"]); + + std::vector ids = {"16", "15", "7", "0", "22"}; + + for(size_t i=0; i < results["hits"].size(); i++) { + ASSERT_EQ(ids[i], results["hits"][i]["document"]["id"].get()); + } + + // try removing a document and doing wildcard (tests the seq_id array used for wildcard searches) + auto remove_op = coll1->remove("0"); + ASSERT_TRUE(remove_op.ok()); + + results = coll1->search("*", {}, "", {}, {}, 1, 30, 1, FREQUENCY, false).get(); + + ASSERT_EQ(23, results["found"]); + ASSERT_EQ(23, results["hits"].size()); + ASSERT_EQ(23, results["out_of"]); + + for(size_t i=23; i >= 1; i--) { + std::string doc_id = (i == 4) ? "foo" : std::to_string(i); + ASSERT_EQ(doc_id, results["hits"][23 - i]["document"]["id"].get()); + } +} + +TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField) { + // when no default sorting field is provided, tokens must be ordered on frequency + Collection *coll1; + std::vector fields = {field("title", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields).get(); + } + + // since only top 10 tokens are fetched for prefixes, the "end" should not show up in the results + std::vector tokens = { + "enter", "elephant", "enamel", "ercot", "enyzme", "energy", + "epoch", "epyc", "express", "everest", "end" + }; + + for(size_t i = 0; i < tokens.size(); i++) { + size_t num_repeat = tokens.size() - i; + + std::string title = tokens[i]; + + for(size_t j = 0; j < num_repeat; j++) { + nlohmann::json doc; + doc["title"] = title; + 
doc["points"] = num_repeat; + coll1->add(doc.dump()); + } + } + + auto results = coll1->search("e", {"title"}, "", {}, {}, 0, 100, 1, NOT_SET, true).get(); + + // 11 + 10 + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2 + ASSERT_EQ(65, results["found"]); + + // we have to ensure that no result contains the word "end" since it occurs least number of times + bool found_end = false; + for(auto& res: results["hits"].items()) { + if(res.value()["document"]["title"] == "end") { + found_end = true; + } + } + + ASSERT_FALSE(found_end); +} + TEST_F(CollectionSortingTest, Int64AsDefaultSortingField) { Collection *coll_mul_fields; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index b66c7f47..7a506a99 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -480,8 +480,6 @@ TEST_F(CollectionTest, WildcardQuery) { nlohmann::json results = collection->search("*", query_fields, "points:>0", {}, sort_fields, 0, 3, 1, FREQUENCY, false).get(); - LOG(INFO) << results; - ASSERT_EQ(3, results["hits"].size()); ASSERT_EQ(25, results["found"].get()); @@ -1639,10 +1637,10 @@ TEST_F(CollectionTest, IndexingWithBadData) { const Option & empty_facet_field_op = sample_collection->add(doc_str); ASSERT_TRUE(empty_facet_field_op.ok()); - doc_str = "{\"name\": \"foo\", \"age\": \"34\", \"tags\": [], \"average\": 34 }"; + doc_str = "{\"name\": \"foo\", \"age\": [\"34\"], \"tags\": [], \"average\": 34 }"; const Option & bad_default_sorting_field_op1 = sample_collection->add(doc_str); ASSERT_FALSE(bad_default_sorting_field_op1.ok()); - ASSERT_STREQ("Default sorting field `age` must be a single valued numerical field.", bad_default_sorting_field_op1.error().c_str()); + ASSERT_STREQ("Field `age` must be an int32.", bad_default_sorting_field_op1.error().c_str()); doc_str = "{\"name\": \"foo\", \"tags\": [], \"average\": 34 }"; const Option & bad_default_sorting_field_op3 = sample_collection->add(doc_str); @@ -2691,8 +2689,6 @@ TEST_F(CollectionTest, MultiFieldRelevance) { auto results = 
coll1->search("Dustin Kensrue Down There by the Train", {"title", "artist"}, "", {}, {}, 0, 10, 1, FREQUENCY).get(); - LOG(INFO) << results; - ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size());