Make default sorting field optional.

This commit is contained in:
kishorenc 2021-02-21 19:55:31 +05:30
parent 222ab345be
commit d2a825799b
14 changed files with 268 additions and 124 deletions

View File

@ -37,9 +37,9 @@ public:
}
// len determines length of output buffer (default: length of input)
uint32_t* uncompress(uint32_t len=0);
uint32_t* uncompress(uint32_t len=0) const;
uint32_t getSizeInBytes();
uint32_t getLength();
uint32_t getLength() const;
};

View File

@ -122,6 +122,8 @@ typedef struct {
} art_document;
// Strategy used to order candidate tokens for a search request.
enum token_ordering {
// No explicit choice was supplied. Collection::search() resolves this to FREQUENCY when the
// collection has no default sorting field, and to MAX_SCORE when it has one.
NOT_SET,
// Chosen by CollectionManager::do_search() when the RANK_TOKENS_BY parameter is "FREQUENCY".
FREQUENCY,
// Chosen by CollectionManager::do_search() when the RANK_TOKENS_BY parameter is
// "DEFAULT_SORTING_FIELD".
MAX_SCORE
};

View File

@ -127,7 +127,7 @@ public:
Option<Collection*> create_collection(const std::string& name, const size_t num_memory_shards,
const std::vector<field> & fields,
const std::string & default_sorting_field,
const std::string & default_sorting_field="",
const uint64_t created_at = static_cast<uint64_t>(std::time(nullptr)),
const bool index_all_fields = false);

View File

@ -164,8 +164,9 @@ struct field {
}
static Option<bool> fields_to_json_fields(const std::vector<field> & fields,
const std::string & default_sorting_field, nlohmann::json& fields_json,
bool& found_default_sorting_field) {
const std::string & default_sorting_field, nlohmann::json& fields_json) {
bool found_default_sorting_field = false;
for(const field & field: fields) {
nlohmann::json field_val;
field_val[fields::name] = field.name;
@ -197,6 +198,11 @@ struct field {
}
}
if(!default_sorting_field.empty() && !found_default_sorting_field) {
return Option<bool>(400, "Default sorting field is defined as `" + default_sorting_field +
"` but is not found in the schema.");
}
return Option<bool>(true);
}
};
@ -276,6 +282,7 @@ namespace sort_field_const {
static const std::string asc = "ASC";
static const std::string desc = "DESC";
static const std::string text_match = "_text_match";
static const std::string seq_id = "_seq_id";
}
struct sort_by {

View File

@ -56,6 +56,7 @@ struct search_args {
size_t typo_tokens_threshold;
std::vector<std::string> group_by_fields;
size_t group_limit;
std::string default_sorting_field;
size_t all_result_ids_len;
spp::sparse_hash_set<uint64_t> groups_processed;
std::vector<std::vector<art_leaf*>> searched_queries;
@ -76,14 +77,15 @@ struct search_args {
std::vector<sort_by> sort_fields_std, facet_query_t facet_query, int num_typos, size_t max_facet_values,
size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix,
size_t drop_tokens_threshold, size_t typo_tokens_threshold,
const std::vector<std::string>& group_by_fields, size_t group_limit):
const std::vector<std::string>& group_by_fields, size_t group_limit,
const std::string& default_sorting_field):
q_include_tokens(q_include_tokens), q_exclude_tokens(q_exclude_tokens), q_synonyms(q_synonyms),
search_fields(search_fields), filters(filters), facets(facets),
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
facet_query(facet_query), num_typos(num_typos), max_facet_values(max_facet_values), per_page(per_page),
page(page), token_order(token_order), prefix(prefix),
drop_tokens_threshold(drop_tokens_threshold), typo_tokens_threshold(typo_tokens_threshold),
group_by_fields(group_by_fields), group_limit(group_limit),
group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field),
all_result_ids_len(0) {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
@ -169,6 +171,9 @@ private:
// sort_field => (seq_id => value)
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, int64_t>*> sort_index;
// this is used for wildcard queries
sorted_array seq_ids;
StringUtils string_utils;
// Internal utility functions
@ -349,7 +354,8 @@ public:
std::vector<std::vector<KV*>> & override_result_kvs,
const size_t typo_tokens_threshold,
const size_t group_limit,
const std::vector<std::string>& group_by_fields) const;
const std::vector<std::string>& group_by_fields,
const std::string& default_sorting_field) const;
Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document);

View File

@ -1,6 +1,6 @@
#include "array_base.h"
uint32_t* array_base::uncompress(uint32_t len) {
uint32_t* array_base::uncompress(uint32_t len) const {
uint32_t actual_len = std::max(len, length);
uint32_t *out = new uint32_t[actual_len];
for_uncompress(in, out, length);
@ -11,6 +11,6 @@ uint32_t array_base::getSizeInBytes() {
return size_bytes;
}
uint32_t array_base::getLength() {
// Returns the `length` member of this array.
// Declared const so it can be invoked from const member functions
// (the wildcard path of the const Index::search() reads seq_ids.getLength()).
uint32_t array_base::getLength() const {
return length;
}

View File

@ -491,7 +491,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
const std::string & simple_filter_query, const std::vector<std::string>& facet_fields,
const std::vector<sort_by> & sort_fields, const int num_typos,
const size_t per_page, const size_t page,
const token_ordering token_order, const bool prefix,
token_ordering token_order, const bool prefix,
const size_t drop_tokens_threshold,
const spp::sparse_hash_set<std::string> & include_fields,
const spp::sparse_hash_set<std::string> & exclude_fields,
@ -755,7 +755,11 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
*/
if(sort_fields_std.empty()) {
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
if(!default_sorting_field.empty()) {
sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
} else {
sort_fields_std.emplace_back(sort_field_const::seq_id, sort_field_const::desc);
}
}
bool found_match_score = false;
@ -801,6 +805,14 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
max_hits = std::min(std::max((page * per_page), max_hits), get_num_documents());
}
if(token_order == NOT_SET) {
if(default_sorting_field.empty()) {
token_order = FREQUENCY;
} else {
token_order = MAX_SCORE;
}
}
std::vector<std::vector<art_leaf*>> searched_queries; // search queries used for generating the results
std::vector<std::vector<KV*>> raw_result_kvs;
std::vector<std::vector<KV*>> override_result_kvs;
@ -833,7 +845,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
sort_fields_std, facet_query, num_typos, max_facet_values, max_hits,
per_page, page, token_order, prefix,
drop_tokens_threshold, typo_tokens_threshold,
group_by_fields, group_limit);
group_by_fields, group_limit, default_sorting_field);
search_args_vec.push_back(search_params);
@ -2264,11 +2276,9 @@ Option<bool> Collection::check_and_update_schema(nlohmann::json& document) {
try {
collection_meta = nlohmann::json::parse(coll_meta_json);
bool found_default_sorting_field = false;
nlohmann::json fields_json = nlohmann::json::array();;
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json,
found_default_sorting_field);
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json);
if(!fields_json_op.ok()) {
return Option<bool>(fields_json_op.code(), fields_json_op.error());

View File

@ -304,7 +304,7 @@ bool CollectionManager::auth_key_matches(const std::string& auth_key_sent,
Option<Collection*> CollectionManager::create_collection(const std::string& name,
const size_t num_memory_shards,
const std::vector<field> & fields,
const std::string & default_sorting_field,
const std::string& default_sorting_field,
const uint64_t created_at,
const bool index_all_fields) {
std::unique_lock lock(mutex);
@ -313,21 +313,14 @@ Option<Collection*> CollectionManager::create_collection(const std::string& name
return Option<Collection*>(409, std::string("A collection with name `") + name + "` already exists.");
}
bool found_default_sorting_field = false;
nlohmann::json fields_json = nlohmann::json::array();;
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json,
found_default_sorting_field);
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json);
if(!fields_json_op.ok()) {
return Option<Collection*>(fields_json_op.code(), fields_json_op.error());
}
if(!found_default_sorting_field) {
return Option<Collection*>(400, "Default sorting field is defined as `" + default_sorting_field +
"` but is not found in the schema.");
}
nlohmann::json collection_meta;
collection_meta[Collection::COLLECTION_NAME_KEY] = name;
collection_meta[Collection::COLLECTION_ID_KEY] = next_collection_id.load();
@ -765,12 +758,16 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const size_t drop_tokens_threshold = (size_t) std::stoi(req_params[DROP_TOKENS_THRESHOLD]);
const size_t typo_tokens_threshold = (size_t) std::stoi(req_params[TYPO_TOKENS_THRESHOLD]);
if(req_params.count(RANK_TOKENS_BY) == 0) {
req_params[RANK_TOKENS_BY] = "DEFAULT_SORTING_FIELD";
}
token_ordering token_order = NOT_SET;
StringUtils::toupper(req_params[RANK_TOKENS_BY]);
token_ordering token_order = (req_params[RANK_TOKENS_BY] == "DEFAULT_SORTING_FIELD") ? MAX_SCORE : FREQUENCY;
if(req_params.count(RANK_TOKENS_BY) != 0) {
StringUtils::toupper(req_params[RANK_TOKENS_BY]);
if (req_params[RANK_TOKENS_BY] == "DEFAULT_SORTING_FIELD") {
token_order = MAX_SCORE;
} else if(req_params[RANK_TOKENS_BY] == "FREQUENCY") {
token_order = FREQUENCY;
}
}
Option<nlohmann::json> result_op = collection->search(req_params[QUERY], search_fields, filter_str, facet_fields,
sort_fields, std::stoi(req_params[NUM_TYPOS]),

View File

@ -108,8 +108,7 @@ bool post_create_collection(http_req & req, http_res & res) {
const char* DEFAULT_SORTING_FIELD = "default_sorting_field";
if(req_json.count(DEFAULT_SORTING_FIELD) == 0) {
res.set_400("Parameter `default_sorting_field` is required.");
return false;
req_json[DEFAULT_SORTING_FIELD] = "";
}
if(!req_json[DEFAULT_SORTING_FIELD].is_string()) {

View File

@ -67,16 +67,14 @@ Index::~Index() {
int64_t Index::get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field) {
int64_t points = 0;
if(!default_sorting_field.empty()) {
if(document[default_sorting_field].is_number_float()) {
// serialize float to an integer and reverse the inverted range
float n = document[default_sorting_field];
memcpy(&points, &n, sizeof(int32_t));
points ^= ((points >> (std::numeric_limits<int32_t>::digits - 1)) | INT32_MIN);
points = -1 * (INT32_MAX - points);
} else {
points = document[default_sorting_field];
}
if(document[default_sorting_field].is_number_float()) {
// serialize float to an integer and reverse the inverted range
float n = document[default_sorting_field];
memcpy(&points, &n, sizeof(int32_t));
points ^= ((points >> (std::numeric_limits<int32_t>::digits - 1)) | INT32_MIN);
points = -1 * (INT32_MAX - points);
} else {
points = document[default_sorting_field];
}
return points;
@ -99,12 +97,20 @@ Option<uint32_t> Index::index_in_memory(const nlohmann::json &document, uint32_t
int64_t points = 0;
if(is_update && document.count(default_sorting_field) == 0) {
points = sort_index[default_sorting_field]->at(seq_id);
if(document.count(default_sorting_field) == 0) {
if(sort_index.count(default_sorting_field) != 0 && sort_index[default_sorting_field]->count(seq_id)) {
points = sort_index[default_sorting_field]->at(seq_id);
} else {
points = INT64_MIN;
}
} else {
points = get_points_from_doc(document, default_sorting_field);
}
if(!is_update) {
seq_ids.append(seq_id);
}
std::unordered_map<std::string, size_t> facet_to_id;
size_t i_facet = 0;
for(const auto & facet: facet_schema) {
@ -266,23 +272,13 @@ Option<uint32_t> Index::validate_index_in_memory(nlohmann::json& document, uint3
bool index_all_fields,
const DIRTY_VALUES& dirty_values) {
bool has_default_sort_field = (document.count(default_sorting_field) != 0);
bool missing_default_sort_field = (!default_sorting_field.empty() && document.count(default_sorting_field) == 0);
if(!has_default_sort_field && !is_update) {
if(!is_update && missing_default_sort_field) {
return Option<>(400, "Field `" + default_sorting_field + "` has been declared as a default sorting field, "
"but is not found in the document.");
}
if(has_default_sort_field &&
!document[default_sorting_field].is_number_integer() && !document[default_sorting_field].is_number_float()) {
return Option<>(400, "Default sorting field `" + default_sorting_field + "` must be a single valued numerical field.");
}
if(has_default_sort_field && search_schema.at(default_sorting_field).is_single_float() &&
document[default_sorting_field].get<float>() > std::numeric_limits<float>::max()) {
return Option<>(400, "Default sorting field `" + default_sorting_field + "` exceeds maximum value of a float.");
}
for(const auto& field_pair: search_schema) {
const std::string& field_name = field_pair.first;
@ -317,18 +313,18 @@ Option<uint32_t> Index::validate_index_in_memory(nlohmann::json& document, uint3
}
}
} else if(field_pair.second.type == field_types::INT64 && !document[field_name].is_number_integer()) {
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, false);
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, -1);
if(!coerce_op.ok()) {
return coerce_op;
}
} else if(field_pair.second.type == field_types::FLOAT && !document[field_name].is_number()) {
// using `is_number` allows integer to be passed to a float field
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, false);
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, -1);
if(!coerce_op.ok()) {
return coerce_op;
}
} else if(field_pair.second.type == field_types::BOOL && !document[field_name].is_boolean()) {
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, false);
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, -1);
if(!coerce_op.ok()) {
return coerce_op;
}
@ -356,18 +352,18 @@ Option<uint32_t> Index::validate_index_in_memory(nlohmann::json& document, uint3
return coerce_op;
}
} else if (field_pair.second.type == field_types::INT64_ARRAY && !item.is_number_integer()) {
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, true);
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, arr_index);
if (!coerce_op.ok()) {
return coerce_op;
}
} else if (field_pair.second.type == field_types::FLOAT_ARRAY && !item.is_number()) {
// we check for `is_number` to allow whole numbers to be passed into float fields
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, true);
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, arr_index);
if (!coerce_op.ok()) {
return coerce_op;
}
} else if (field_pair.second.type == field_types::BOOL_ARRAY && !item.is_boolean()) {
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, true);
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, arr_index);
if (!coerce_op.ok()) {
return coerce_op;
}
@ -1335,7 +1331,8 @@ void Index::run_search(search_args* search_params) {
search_params->searched_queries,
search_params->raw_result_kvs, search_params->override_result_kvs,
search_params->typo_tokens_threshold,
search_params->group_limit, search_params->group_by_fields);
search_params->group_limit, search_params->group_by_fields,
search_params->default_sorting_field);
}
void Index::collate_included_ids(const std::vector<std::string>& q_included_tokens,
@ -1427,7 +1424,8 @@ void Index::search(const std::vector<std::string>& q_include_tokens,
std::vector<std::vector<KV*>> & override_result_kvs,
const size_t typo_tokens_threshold,
const size_t group_limit,
const std::vector<std::string>& group_by_fields) const {
const std::vector<std::string>& group_by_fields,
const std::string& default_sorting_field) const {
std::shared_lock lock(mutex);
@ -1491,27 +1489,22 @@ void Index::search(const std::vector<std::string>& q_include_tokens,
// if a filter is not specified, use the sorting index to generate the list of all document ids
if(filters.empty()) {
std::string all_records_field;
if(default_sorting_field.empty()) {
filter_ids_length = seq_ids.getLength();
filter_ids = seq_ids.uncompress();
} else {
const spp::sparse_hash_map<uint32_t, int64_t> *kvs = sort_index.at(default_sorting_field);
filter_ids_length = kvs->size();
filter_ids = new uint32_t[filter_ids_length];
// get the first non-optional field
for(const auto& kv: sort_schema) {
if(!kv.second.optional && kv.first != sort_field_const::text_match) {
all_records_field = kv.first;
break;
size_t i = 0;
for(const auto& kv: *kvs) {
filter_ids[i++] = kv.first;
}
// ids populated from hash map will not be sorted, but sorting is required for intersection & other ops
std::sort(filter_ids, filter_ids+filter_ids_length);
}
const spp::sparse_hash_map<uint32_t, int64_t> *kvs = sort_index.at(all_records_field);
filter_ids_length = kvs->size();
filter_ids = new uint32_t[filter_ids_length];
size_t i = 0;
for(const auto& kv: *kvs) {
filter_ids[i++] = kv.first;
}
// ids populated from hash map will not be sorted, but sorting is required for intersection & other ops
std::sort(filter_ids, filter_ids+filter_ids_length);
}
if(!curated_ids.empty()) {
@ -1919,7 +1912,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
const size_t group_limit, const std::vector<std::string>& group_by_fields,
uint32_t token_bits) const {
std::vector<uint32_t*> leaf_to_indices;
std::vector<uint32_t *> leaf_to_indices;
for (art_leaf *token_leaf: query_suggestion) {
uint32_t *indices = new uint32_t[result_size];
token_leaf->values->ids.indexOf(result_ids, result_size, indices);
@ -1937,19 +1930,25 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
spp::sparse_hash_map<uint32_t, int64_t> geopoint_distances[3];
for(size_t i = 0; i < sort_fields.size(); i++) {
spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value, seq_id_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> *TEXT_MATCH_SENTINEL = &text_match_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> *SEQ_ID_SENTINEL = &seq_id_sentinel_value;
for (size_t i = 0; i < sort_fields.size(); i++) {
sort_order[i] = 1;
if(sort_fields[i].order == sort_field_const::asc) {
if (sort_fields[i].order == sort_field_const::asc) {
sort_order[i] = -1;
}
if(sort_fields[i].name == sort_field_const::text_match) {
field_values[i] = nullptr;
} else if(sort_schema.at(sort_fields[i].name).is_geopoint()) {
if (sort_fields[i].name == sort_field_const::text_match) {
field_values[i] = TEXT_MATCH_SENTINEL;
} else if (sort_fields[i].name == sort_field_const::seq_id) {
field_values[i] = SEQ_ID_SENTINEL;
} else if (sort_schema.at(sort_fields[i].name).is_geopoint()) {
// we have to populate distances that will be used for match scoring
spp::sparse_hash_map<uint32_t, int64_t>* geopoints = sort_index.at(sort_fields[i].name);
spp::sparse_hash_map<uint32_t, int64_t> *geopoints = sort_index.at(sort_fields[i].name);
for(size_t rindex=0; rindex<result_size; rindex++) {
for (size_t rindex = 0; rindex < result_size; rindex++) {
const uint32_t seq_id = result_ids[rindex];
auto it = geopoints->find(seq_id);
int64_t dist = (it == geopoints->end()) ? INT32_MAX : h3Distance(sort_fields[i].geopoint, it->second);
@ -1964,23 +1963,23 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
//auto begin = std::chrono::high_resolution_clock::now();
for(size_t i=0; i<result_size; i++) {
for (size_t i = 0; i < result_size; i++) {
const uint32_t seq_id = result_ids[i];
uint64_t match_score = 0;
if(query_suggestion.size() <= 1) {
if (query_suggestion.size() <= 1) {
match_score = single_token_match_score;
} else {
std::unordered_map<size_t, std::vector<std::vector<uint16_t>>> array_token_positions;
populate_token_positions(query_suggestion, leaf_to_indices, i, array_token_positions);
for(const auto& kv: array_token_positions) {
const std::vector<std::vector<uint16_t>>& token_positions = kv.second;
if(token_positions.empty()) {
for (const auto& kv: array_token_positions) {
const std::vector<std::vector<uint16_t>> &token_positions = kv.second;
if (token_positions.empty()) {
continue;
}
const Match & match = Match(seq_id, token_positions, false);
const Match &match = Match(seq_id, token_positions, false);
uint64_t this_match_score = match.get_match_score(total_cost);
match_score += this_match_score;
@ -2000,40 +1999,49 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
size_t match_score_index = 0;
// avoiding loop
if(sort_fields.size() > 0) {
if (field_values[0] != nullptr) {
auto it = field_values[0]->find(seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
} else {
if (sort_fields.size() > 0) {
if (field_values[0] == TEXT_MATCH_SENTINEL) {
scores[0] = int64_t(match_score);
match_score_index = 0;
} else if (field_values[0] == SEQ_ID_SENTINEL) {
scores[0] = seq_id;
} else {
auto it = field_values[0]->find(seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
}
if (sort_order[0] == -1) {
scores[0] = -scores[0];
}
}
if(sort_fields.size() > 1) {
if (field_values[1] != nullptr) {
auto it = field_values[1]->find(seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
} else {
if (field_values[1] == TEXT_MATCH_SENTINEL) {
scores[1] = int64_t(match_score);
match_score_index = 1;
} else if (field_values[1] == SEQ_ID_SENTINEL) {
scores[1] = seq_id;
} else {
auto it = field_values[1]->find(seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
}
if (sort_order[1] == -1) {
scores[1] = -scores[1];
}
}
if(sort_fields.size() > 2) {
if(field_values[2] != nullptr) {
auto it = field_values[2]->find(seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
} else {
// Third sort slot: same dispatch as slots 0 and 1. The two sentinel pointers stand for the
// virtual `_text_match` and `_seq_id` sort fields; any other pointer is a real sort index.
// The comparison must be `==`: with `!=`, every non-text-match field would be scored with the
// text match score and the `_seq_id` branch could never be reached.
if (field_values[2] == TEXT_MATCH_SENTINEL) {
    scores[2] = int64_t(match_score);
    match_score_index = 2;
} else if (field_values[2] == SEQ_ID_SENTINEL) {
    scores[2] = seq_id;
} else {
    // a document without a value for this sort field falls back to `default_score`
    auto it = field_values[2]->find(seq_id);
    scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
}
if(sort_order[2] == -1) {
scores[2] = -scores[2];
}
@ -2314,6 +2322,8 @@ Option<uint32_t> Index::remove(const uint32_t seq_id, const nlohmann::json & doc
}
}
seq_ids.remove_value(seq_id);
return Option<uint32_t>(seq_id);
}

View File

@ -780,10 +780,6 @@ TEST_F(CollectionFacetingTest, FacetCountOnSimilarStrings) {
token_ordering::FREQUENCY, true, 10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10).get();
LOG(INFO) << results;
return;
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ(2, results["facet_counts"][0]["counts"].size());

View File

@ -304,6 +304,13 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) {
ASSERT_EQ(1, coll1->get_collection_id());
ASSERT_EQ(3, coll1->get_sort_fields().size());
// index a document with a bad field value with COERCE_OR_IGNORE setting
auto doc_json = R"({"title": "Unique record.", "max": 25, "scores": [22, "how", 44],
"average": "bad data", "is_valid": true})";
Option<nlohmann::json> add_op = coll1->add(doc_json, CREATE, "", DIRTY_VALUES::COERCE_OR_IGNORE);
ASSERT_TRUE(add_op.ok());
std::unordered_map<std::string, field> schema = collection1->get_schema();
// create a new collection manager to ensure that it restores the records from the disk backed store
@ -324,7 +331,8 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) {
auto restored_schema = restored_coll->get_schema();
ASSERT_EQ(1, restored_coll->get_collection_id());
ASSERT_EQ(6, restored_coll->get_next_seq_id());
ASSERT_EQ(7, restored_coll->get_next_seq_id());
ASSERT_EQ(7, restored_coll->get_num_documents());
ASSERT_EQ(facet_fields_expected, restored_coll->get_facet_fields());
ASSERT_EQ(3, restored_coll->get_sort_fields().size());
ASSERT_EQ("is_valid", restored_coll->get_sort_fields()[0].name);
@ -347,6 +355,24 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) {
ASSERT_FALSE(kv.second.optional);
}
// try searching for record with bad data
auto results = restored_coll->search("unique", {"title"}, "", {}, {}, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("Unique record.", results["hits"][0]["document"]["title"].get<std::string>().c_str());
ASSERT_EQ(0, results["hits"][0]["document"].count("average"));
ASSERT_EQ(2, results["hits"][0]["document"]["scores"].size());
ASSERT_EQ(22, results["hits"][0]["document"]["scores"][0]);
ASSERT_EQ(44, results["hits"][0]["document"]["scores"][1]);
// try sorting on `average`, a field that not all records have
ASSERT_EQ(7, restored_coll->get_num_documents());
sort_fields = { sort_by("average", "DESC") };
results = restored_coll->search("*", {"title"}, "", {}, {sort_fields}, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(7, results["hits"].size());
collectionManager.drop_collection("coll1");
collectionManager2.drop_collection("coll1");
}

View File

@ -129,19 +129,114 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
std::vector<sort_by> sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
Option<Collection*> collection_op = collectionManager.create_collection("sample_collection", 4, fields, "name");
EXPECT_FALSE(collection_op.ok());
EXPECT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error());
ASSERT_FALSE(collection_op.ok());
ASSERT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error());
collectionManager.drop_collection("sample_collection");
// Default sorting field must exist as a field in schema
sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
collection_op = collectionManager.create_collection("sample_collection", 4, fields, "NOT-DEFINED");
EXPECT_FALSE(collection_op.ok());
EXPECT_EQ("Default sorting field is defined as `NOT-DEFINED` but is not found in the schema.", collection_op.error());
ASSERT_FALSE(collection_op.ok());
ASSERT_EQ("Default sorting field is defined as `NOT-DEFINED` but is not found in the schema.", collection_op.error());
collectionManager.drop_collection("sample_collection");
}
// Exercises a collection that is created without a default sorting field:
//  * keyword search results are expected in (text_match, seq_id) descending order;
//  * a wildcard search after a removal is expected to list the remaining documents by
//    descending seq_id.
TEST_F(CollectionSortingTest, NoDefaultSortingField) {
    Collection *coll1;

    std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl");
    std::vector<field> fields = {field("title", field_types::STRING, false),
                                 field("points", field_types::INT32, false)};

    coll1 = collectionManager.get_collection("coll1").get();
    if(coll1 == nullptr) {
        // the default_sorting_field argument is deliberately left at its default value
        coll1 = collectionManager.create_collection("coll1", 4, fields).get();
    }

    std::string json_line;

    while (std::getline(infile, json_line)) {
        coll1->add(json_line);
    }

    infile.close();

    // without a default sorting field, matches should be sorted by (text_match, seq_id)
    auto results = coll1->search("rocket", {"title"}, "", {}, {}, 1, 10, 1, FREQUENCY, false).get();

    ASSERT_EQ(5, results["found"]);
    ASSERT_EQ(5, results["hits"].size());
    ASSERT_EQ(24, results["out_of"]);

    std::vector<std::string> ids = {"16", "15", "7", "0", "22"};

    for(size_t i=0; i < results["hits"].size(); i++) {
        ASSERT_EQ(ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // try removing a document and doing wildcard (tests the seq_id array used for wildcard searches)
    auto remove_op = coll1->remove("0");
    ASSERT_TRUE(remove_op.ok());

    results = coll1->search("*", {}, "", {}, {}, 1, 30, 1, FREQUENCY, false).get();

    ASSERT_EQ(23, results["found"]);
    ASSERT_EQ(23, results["hits"].size());
    ASSERT_EQ(23, results["out_of"]);

    // hits are walked from the highest position down to 1; position 4 carries the id "foo"
    for(size_t i=23; i >= 1; i--) {
        std::string doc_id = (i == 4) ? "foo" : std::to_string(i);
        ASSERT_EQ(doc_id, results["hits"][23 - i]["document"]["id"].get<std::string>());
    }

    // drop the collection so that other tests which reuse the name "coll1" start from a clean state
    collectionManager.drop_collection("coll1");
}
// Checks that a search issued with token ordering NOT_SET against a collection without a
// default sorting field ranks prefix candidates by frequency: the least frequent token
// ("end") must not appear among the hits.
TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField) {
    // when no default sorting field is provided, tokens must be ordered on frequency
    Collection *coll1;

    std::vector<field> fields = {field("title", field_types::STRING, false),
                                 field("points", field_types::INT32, false)};

    coll1 = collectionManager.get_collection("coll1").get();
    if(coll1 == nullptr) {
        coll1 = collectionManager.create_collection("coll1", 1, fields).get();
    }

    // since only top 10 tokens are fetched for prefixes, the "end" should not show up in the results
    std::vector<std::string> tokens = {
        "enter", "elephant", "enamel", "ercot", "enyzme", "energy",
        "epoch", "epyc", "express", "everest", "end"
    };

    // token i is indexed (tokens.size() - i) times, so earlier tokens are more frequent
    for(size_t i = 0; i < tokens.size(); i++) {
        size_t num_repeat = tokens.size() - i;

        std::string title = tokens[i];

        for(size_t j = 0; j < num_repeat; j++) {
            nlohmann::json doc;
            doc["title"] = title;
            doc["points"] = num_repeat;
            coll1->add(doc.dump());
        }
    }

    auto results = coll1->search("e", {"title"}, "", {}, {}, 0, 100, 1, NOT_SET, true).get();

    // 11 + 10 + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2
    ASSERT_EQ(65, results["found"]);

    // we have to ensure that no result contains the word "end" since it occurs least number of times
    bool found_end = false;
    for(auto& res: results["hits"].items()) {
        if(res.value()["document"]["title"] == "end") {
            found_end = true;
        }
    }

    ASSERT_FALSE(found_end);

    // drop the collection so that other tests which reuse the name "coll1" start from a clean state
    collectionManager.drop_collection("coll1");
}
TEST_F(CollectionSortingTest, Int64AsDefaultSortingField) {
Collection *coll_mul_fields;

View File

@ -480,8 +480,6 @@ TEST_F(CollectionTest, WildcardQuery) {
nlohmann::json results = collection->search("*", query_fields, "points:>0", {}, sort_fields, 0, 3, 1, FREQUENCY,
false).get();
LOG(INFO) << results;
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ(25, results["found"].get<uint32_t>());
@ -1639,10 +1637,10 @@ TEST_F(CollectionTest, IndexingWithBadData) {
const Option<nlohmann::json> & empty_facet_field_op = sample_collection->add(doc_str);
ASSERT_TRUE(empty_facet_field_op.ok());
doc_str = "{\"name\": \"foo\", \"age\": \"34\", \"tags\": [], \"average\": 34 }";
doc_str = "{\"name\": \"foo\", \"age\": [\"34\"], \"tags\": [], \"average\": 34 }";
const Option<nlohmann::json> & bad_default_sorting_field_op1 = sample_collection->add(doc_str);
ASSERT_FALSE(bad_default_sorting_field_op1.ok());
ASSERT_STREQ("Default sorting field `age` must be a single valued numerical field.", bad_default_sorting_field_op1.error().c_str());
ASSERT_STREQ("Field `age` must be an int32.", bad_default_sorting_field_op1.error().c_str());
doc_str = "{\"name\": \"foo\", \"tags\": [], \"average\": 34 }";
const Option<nlohmann::json> & bad_default_sorting_field_op3 = sample_collection->add(doc_str);
@ -2691,8 +2689,6 @@ TEST_F(CollectionTest, MultiFieldRelevance) {
auto results = coll1->search("Dustin Kensrue Down There by the Train",
{"title", "artist"}, "", {}, {}, 0, 10, 1, FREQUENCY).get();
LOG(INFO) << results;
ASSERT_EQ(3, results["found"].get<size_t>());
ASSERT_EQ(3, results["hits"].size());