mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 06:02:26 +08:00
Make default sorting field optional.
This commit is contained in:
parent
222ab345be
commit
d2a825799b
@ -37,9 +37,9 @@ public:
|
||||
}
|
||||
|
||||
// len determines length of output buffer (default: length of input)
|
||||
uint32_t* uncompress(uint32_t len=0);
|
||||
uint32_t* uncompress(uint32_t len=0) const;
|
||||
|
||||
uint32_t getSizeInBytes();
|
||||
|
||||
uint32_t getLength();
|
||||
uint32_t getLength() const;
|
||||
};
|
@ -122,6 +122,8 @@ typedef struct {
|
||||
} art_document;
|
||||
|
||||
enum token_ordering {
|
||||
NOT_SET,
|
||||
|
||||
FREQUENCY,
|
||||
MAX_SCORE
|
||||
};
|
||||
|
@ -127,7 +127,7 @@ public:
|
||||
|
||||
Option<Collection*> create_collection(const std::string& name, const size_t num_memory_shards,
|
||||
const std::vector<field> & fields,
|
||||
const std::string & default_sorting_field,
|
||||
const std::string & default_sorting_field="",
|
||||
const uint64_t created_at = static_cast<uint64_t>(std::time(nullptr)),
|
||||
const bool index_all_fields = false);
|
||||
|
||||
|
@ -164,8 +164,9 @@ struct field {
|
||||
}
|
||||
|
||||
static Option<bool> fields_to_json_fields(const std::vector<field> & fields,
|
||||
const std::string & default_sorting_field, nlohmann::json& fields_json,
|
||||
bool& found_default_sorting_field) {
|
||||
const std::string & default_sorting_field, nlohmann::json& fields_json) {
|
||||
bool found_default_sorting_field = false;
|
||||
|
||||
for(const field & field: fields) {
|
||||
nlohmann::json field_val;
|
||||
field_val[fields::name] = field.name;
|
||||
@ -197,6 +198,11 @@ struct field {
|
||||
}
|
||||
}
|
||||
|
||||
if(!default_sorting_field.empty() && !found_default_sorting_field) {
|
||||
return Option<bool>(400, "Default sorting field is defined as `" + default_sorting_field +
|
||||
"` but is not found in the schema.");
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
};
|
||||
@ -276,6 +282,7 @@ namespace sort_field_const {
|
||||
static const std::string asc = "ASC";
|
||||
static const std::string desc = "DESC";
|
||||
static const std::string text_match = "_text_match";
|
||||
static const std::string seq_id = "_seq_id";
|
||||
}
|
||||
|
||||
struct sort_by {
|
||||
|
@ -56,6 +56,7 @@ struct search_args {
|
||||
size_t typo_tokens_threshold;
|
||||
std::vector<std::string> group_by_fields;
|
||||
size_t group_limit;
|
||||
std::string default_sorting_field;
|
||||
size_t all_result_ids_len;
|
||||
spp::sparse_hash_set<uint64_t> groups_processed;
|
||||
std::vector<std::vector<art_leaf*>> searched_queries;
|
||||
@ -76,14 +77,15 @@ struct search_args {
|
||||
std::vector<sort_by> sort_fields_std, facet_query_t facet_query, int num_typos, size_t max_facet_values,
|
||||
size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix,
|
||||
size_t drop_tokens_threshold, size_t typo_tokens_threshold,
|
||||
const std::vector<std::string>& group_by_fields, size_t group_limit):
|
||||
const std::vector<std::string>& group_by_fields, size_t group_limit,
|
||||
const std::string& default_sorting_field):
|
||||
q_include_tokens(q_include_tokens), q_exclude_tokens(q_exclude_tokens), q_synonyms(q_synonyms),
|
||||
search_fields(search_fields), filters(filters), facets(facets),
|
||||
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
|
||||
facet_query(facet_query), num_typos(num_typos), max_facet_values(max_facet_values), per_page(per_page),
|
||||
page(page), token_order(token_order), prefix(prefix),
|
||||
drop_tokens_threshold(drop_tokens_threshold), typo_tokens_threshold(typo_tokens_threshold),
|
||||
group_by_fields(group_by_fields), group_limit(group_limit),
|
||||
group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field),
|
||||
all_result_ids_len(0) {
|
||||
|
||||
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
|
||||
@ -169,6 +171,9 @@ private:
|
||||
// sort_field => (seq_id => value)
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, int64_t>*> sort_index;
|
||||
|
||||
// this is used for wildcard queries
|
||||
sorted_array seq_ids;
|
||||
|
||||
StringUtils string_utils;
|
||||
|
||||
// Internal utility functions
|
||||
@ -349,7 +354,8 @@ public:
|
||||
std::vector<std::vector<KV*>> & override_result_kvs,
|
||||
const size_t typo_tokens_threshold,
|
||||
const size_t group_limit,
|
||||
const std::vector<std::string>& group_by_fields) const;
|
||||
const std::vector<std::string>& group_by_fields,
|
||||
const std::string& default_sorting_field) const;
|
||||
|
||||
Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document);
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include "array_base.h"
|
||||
|
||||
uint32_t* array_base::uncompress(uint32_t len) {
|
||||
uint32_t* array_base::uncompress(uint32_t len) const {
|
||||
uint32_t actual_len = std::max(len, length);
|
||||
uint32_t *out = new uint32_t[actual_len];
|
||||
for_uncompress(in, out, length);
|
||||
@ -11,6 +11,6 @@ uint32_t array_base::getSizeInBytes() {
|
||||
return size_bytes;
|
||||
}
|
||||
|
||||
uint32_t array_base::getLength() {
|
||||
uint32_t array_base::getLength() const {
|
||||
return length;
|
||||
}
|
||||
|
@ -491,7 +491,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
const std::string & simple_filter_query, const std::vector<std::string>& facet_fields,
|
||||
const std::vector<sort_by> & sort_fields, const int num_typos,
|
||||
const size_t per_page, const size_t page,
|
||||
const token_ordering token_order, const bool prefix,
|
||||
token_ordering token_order, const bool prefix,
|
||||
const size_t drop_tokens_threshold,
|
||||
const spp::sparse_hash_set<std::string> & include_fields,
|
||||
const spp::sparse_hash_set<std::string> & exclude_fields,
|
||||
@ -755,7 +755,11 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
*/
|
||||
if(sort_fields_std.empty()) {
|
||||
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
|
||||
sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
|
||||
if(!default_sorting_field.empty()) {
|
||||
sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
|
||||
} else {
|
||||
sort_fields_std.emplace_back(sort_field_const::seq_id, sort_field_const::desc);
|
||||
}
|
||||
}
|
||||
|
||||
bool found_match_score = false;
|
||||
@ -801,6 +805,14 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
max_hits = std::min(std::max((page * per_page), max_hits), get_num_documents());
|
||||
}
|
||||
|
||||
if(token_order == NOT_SET) {
|
||||
if(default_sorting_field.empty()) {
|
||||
token_order = FREQUENCY;
|
||||
} else {
|
||||
token_order = MAX_SCORE;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<art_leaf*>> searched_queries; // search queries used for generating the results
|
||||
std::vector<std::vector<KV*>> raw_result_kvs;
|
||||
std::vector<std::vector<KV*>> override_result_kvs;
|
||||
@ -833,7 +845,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
sort_fields_std, facet_query, num_typos, max_facet_values, max_hits,
|
||||
per_page, page, token_order, prefix,
|
||||
drop_tokens_threshold, typo_tokens_threshold,
|
||||
group_by_fields, group_limit);
|
||||
group_by_fields, group_limit, default_sorting_field);
|
||||
|
||||
search_args_vec.push_back(search_params);
|
||||
|
||||
@ -2264,11 +2276,9 @@ Option<bool> Collection::check_and_update_schema(nlohmann::json& document) {
|
||||
|
||||
try {
|
||||
collection_meta = nlohmann::json::parse(coll_meta_json);
|
||||
bool found_default_sorting_field = false;
|
||||
nlohmann::json fields_json = nlohmann::json::array();;
|
||||
|
||||
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json,
|
||||
found_default_sorting_field);
|
||||
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json);
|
||||
|
||||
if(!fields_json_op.ok()) {
|
||||
return Option<bool>(fields_json_op.code(), fields_json_op.error());
|
||||
|
@ -304,7 +304,7 @@ bool CollectionManager::auth_key_matches(const std::string& auth_key_sent,
|
||||
Option<Collection*> CollectionManager::create_collection(const std::string& name,
|
||||
const size_t num_memory_shards,
|
||||
const std::vector<field> & fields,
|
||||
const std::string & default_sorting_field,
|
||||
const std::string& default_sorting_field,
|
||||
const uint64_t created_at,
|
||||
const bool index_all_fields) {
|
||||
std::unique_lock lock(mutex);
|
||||
@ -313,21 +313,14 @@ Option<Collection*> CollectionManager::create_collection(const std::string& name
|
||||
return Option<Collection*>(409, std::string("A collection with name `") + name + "` already exists.");
|
||||
}
|
||||
|
||||
bool found_default_sorting_field = false;
|
||||
nlohmann::json fields_json = nlohmann::json::array();;
|
||||
|
||||
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json,
|
||||
found_default_sorting_field);
|
||||
Option<bool> fields_json_op = field::fields_to_json_fields(fields, default_sorting_field, fields_json);
|
||||
|
||||
if(!fields_json_op.ok()) {
|
||||
return Option<Collection*>(fields_json_op.code(), fields_json_op.error());
|
||||
}
|
||||
|
||||
if(!found_default_sorting_field) {
|
||||
return Option<Collection*>(400, "Default sorting field is defined as `" + default_sorting_field +
|
||||
"` but is not found in the schema.");
|
||||
}
|
||||
|
||||
nlohmann::json collection_meta;
|
||||
collection_meta[Collection::COLLECTION_NAME_KEY] = name;
|
||||
collection_meta[Collection::COLLECTION_ID_KEY] = next_collection_id.load();
|
||||
@ -765,12 +758,16 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
const size_t drop_tokens_threshold = (size_t) std::stoi(req_params[DROP_TOKENS_THRESHOLD]);
|
||||
const size_t typo_tokens_threshold = (size_t) std::stoi(req_params[TYPO_TOKENS_THRESHOLD]);
|
||||
|
||||
if(req_params.count(RANK_TOKENS_BY) == 0) {
|
||||
req_params[RANK_TOKENS_BY] = "DEFAULT_SORTING_FIELD";
|
||||
}
|
||||
token_ordering token_order = NOT_SET;
|
||||
|
||||
StringUtils::toupper(req_params[RANK_TOKENS_BY]);
|
||||
token_ordering token_order = (req_params[RANK_TOKENS_BY] == "DEFAULT_SORTING_FIELD") ? MAX_SCORE : FREQUENCY;
|
||||
if(req_params.count(RANK_TOKENS_BY) != 0) {
|
||||
StringUtils::toupper(req_params[RANK_TOKENS_BY]);
|
||||
if (req_params[RANK_TOKENS_BY] == "DEFAULT_SORTING_FIELD") {
|
||||
token_order = MAX_SCORE;
|
||||
} else if(req_params[RANK_TOKENS_BY] == "FREQUENCY") {
|
||||
token_order = FREQUENCY;
|
||||
}
|
||||
}
|
||||
|
||||
Option<nlohmann::json> result_op = collection->search(req_params[QUERY], search_fields, filter_str, facet_fields,
|
||||
sort_fields, std::stoi(req_params[NUM_TYPOS]),
|
||||
|
@ -108,8 +108,7 @@ bool post_create_collection(http_req & req, http_res & res) {
|
||||
const char* DEFAULT_SORTING_FIELD = "default_sorting_field";
|
||||
|
||||
if(req_json.count(DEFAULT_SORTING_FIELD) == 0) {
|
||||
res.set_400("Parameter `default_sorting_field` is required.");
|
||||
return false;
|
||||
req_json[DEFAULT_SORTING_FIELD] = "";
|
||||
}
|
||||
|
||||
if(!req_json[DEFAULT_SORTING_FIELD].is_string()) {
|
||||
|
164
src/index.cpp
164
src/index.cpp
@ -67,16 +67,14 @@ Index::~Index() {
|
||||
int64_t Index::get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field) {
|
||||
int64_t points = 0;
|
||||
|
||||
if(!default_sorting_field.empty()) {
|
||||
if(document[default_sorting_field].is_number_float()) {
|
||||
// serialize float to an integer and reverse the inverted range
|
||||
float n = document[default_sorting_field];
|
||||
memcpy(&points, &n, sizeof(int32_t));
|
||||
points ^= ((points >> (std::numeric_limits<int32_t>::digits - 1)) | INT32_MIN);
|
||||
points = -1 * (INT32_MAX - points);
|
||||
} else {
|
||||
points = document[default_sorting_field];
|
||||
}
|
||||
if(document[default_sorting_field].is_number_float()) {
|
||||
// serialize float to an integer and reverse the inverted range
|
||||
float n = document[default_sorting_field];
|
||||
memcpy(&points, &n, sizeof(int32_t));
|
||||
points ^= ((points >> (std::numeric_limits<int32_t>::digits - 1)) | INT32_MIN);
|
||||
points = -1 * (INT32_MAX - points);
|
||||
} else {
|
||||
points = document[default_sorting_field];
|
||||
}
|
||||
|
||||
return points;
|
||||
@ -99,12 +97,20 @@ Option<uint32_t> Index::index_in_memory(const nlohmann::json &document, uint32_t
|
||||
|
||||
int64_t points = 0;
|
||||
|
||||
if(is_update && document.count(default_sorting_field) == 0) {
|
||||
points = sort_index[default_sorting_field]->at(seq_id);
|
||||
if(document.count(default_sorting_field) == 0) {
|
||||
if(sort_index.count(default_sorting_field) != 0 && sort_index[default_sorting_field]->count(seq_id)) {
|
||||
points = sort_index[default_sorting_field]->at(seq_id);
|
||||
} else {
|
||||
points = INT64_MIN;
|
||||
}
|
||||
} else {
|
||||
points = get_points_from_doc(document, default_sorting_field);
|
||||
}
|
||||
|
||||
if(!is_update) {
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, size_t> facet_to_id;
|
||||
size_t i_facet = 0;
|
||||
for(const auto & facet: facet_schema) {
|
||||
@ -266,23 +272,13 @@ Option<uint32_t> Index::validate_index_in_memory(nlohmann::json& document, uint3
|
||||
bool index_all_fields,
|
||||
const DIRTY_VALUES& dirty_values) {
|
||||
|
||||
bool has_default_sort_field = (document.count(default_sorting_field) != 0);
|
||||
bool missing_default_sort_field = (!default_sorting_field.empty() && document.count(default_sorting_field) == 0);
|
||||
|
||||
if(!has_default_sort_field && !is_update) {
|
||||
if(!is_update && missing_default_sort_field) {
|
||||
return Option<>(400, "Field `" + default_sorting_field + "` has been declared as a default sorting field, "
|
||||
"but is not found in the document.");
|
||||
}
|
||||
|
||||
if(has_default_sort_field &&
|
||||
!document[default_sorting_field].is_number_integer() && !document[default_sorting_field].is_number_float()) {
|
||||
return Option<>(400, "Default sorting field `" + default_sorting_field + "` must be a single valued numerical field.");
|
||||
}
|
||||
|
||||
if(has_default_sort_field && search_schema.at(default_sorting_field).is_single_float() &&
|
||||
document[default_sorting_field].get<float>() > std::numeric_limits<float>::max()) {
|
||||
return Option<>(400, "Default sorting field `" + default_sorting_field + "` exceeds maximum value of a float.");
|
||||
}
|
||||
|
||||
for(const auto& field_pair: search_schema) {
|
||||
const std::string& field_name = field_pair.first;
|
||||
|
||||
@ -317,18 +313,18 @@ Option<uint32_t> Index::validate_index_in_memory(nlohmann::json& document, uint3
|
||||
}
|
||||
}
|
||||
} else if(field_pair.second.type == field_types::INT64 && !document[field_name].is_number_integer()) {
|
||||
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, false);
|
||||
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, -1);
|
||||
if(!coerce_op.ok()) {
|
||||
return coerce_op;
|
||||
}
|
||||
} else if(field_pair.second.type == field_types::FLOAT && !document[field_name].is_number()) {
|
||||
// using `is_number` allows integer to be passed to a float field
|
||||
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, false);
|
||||
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, -1);
|
||||
if(!coerce_op.ok()) {
|
||||
return coerce_op;
|
||||
}
|
||||
} else if(field_pair.second.type == field_types::BOOL && !document[field_name].is_boolean()) {
|
||||
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, false);
|
||||
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, -1);
|
||||
if(!coerce_op.ok()) {
|
||||
return coerce_op;
|
||||
}
|
||||
@ -356,18 +352,18 @@ Option<uint32_t> Index::validate_index_in_memory(nlohmann::json& document, uint3
|
||||
return coerce_op;
|
||||
}
|
||||
} else if (field_pair.second.type == field_types::INT64_ARRAY && !item.is_number_integer()) {
|
||||
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, true);
|
||||
Option<uint32_t> coerce_op = coerce_int64_t(dirty_values, document, field_name, arr_index);
|
||||
if (!coerce_op.ok()) {
|
||||
return coerce_op;
|
||||
}
|
||||
} else if (field_pair.second.type == field_types::FLOAT_ARRAY && !item.is_number()) {
|
||||
// we check for `is_number` to allow whole numbers to be passed into float fields
|
||||
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, true);
|
||||
Option<uint32_t> coerce_op = coerce_float(dirty_values, document, field_name, arr_index);
|
||||
if (!coerce_op.ok()) {
|
||||
return coerce_op;
|
||||
}
|
||||
} else if (field_pair.second.type == field_types::BOOL_ARRAY && !item.is_boolean()) {
|
||||
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, true);
|
||||
Option<uint32_t> coerce_op = coerce_bool(dirty_values, document, field_name, arr_index);
|
||||
if (!coerce_op.ok()) {
|
||||
return coerce_op;
|
||||
}
|
||||
@ -1335,7 +1331,8 @@ void Index::run_search(search_args* search_params) {
|
||||
search_params->searched_queries,
|
||||
search_params->raw_result_kvs, search_params->override_result_kvs,
|
||||
search_params->typo_tokens_threshold,
|
||||
search_params->group_limit, search_params->group_by_fields);
|
||||
search_params->group_limit, search_params->group_by_fields,
|
||||
search_params->default_sorting_field);
|
||||
}
|
||||
|
||||
void Index::collate_included_ids(const std::vector<std::string>& q_included_tokens,
|
||||
@ -1427,7 +1424,8 @@ void Index::search(const std::vector<std::string>& q_include_tokens,
|
||||
std::vector<std::vector<KV*>> & override_result_kvs,
|
||||
const size_t typo_tokens_threshold,
|
||||
const size_t group_limit,
|
||||
const std::vector<std::string>& group_by_fields) const {
|
||||
const std::vector<std::string>& group_by_fields,
|
||||
const std::string& default_sorting_field) const {
|
||||
|
||||
std::shared_lock lock(mutex);
|
||||
|
||||
@ -1491,27 +1489,22 @@ void Index::search(const std::vector<std::string>& q_include_tokens,
|
||||
|
||||
// if a filter is not specified, use the sorting index to generate the list of all document ids
|
||||
if(filters.empty()) {
|
||||
std::string all_records_field;
|
||||
if(default_sorting_field.empty()) {
|
||||
filter_ids_length = seq_ids.getLength();
|
||||
filter_ids = seq_ids.uncompress();
|
||||
} else {
|
||||
const spp::sparse_hash_map<uint32_t, int64_t> *kvs = sort_index.at(default_sorting_field);
|
||||
filter_ids_length = kvs->size();
|
||||
filter_ids = new uint32_t[filter_ids_length];
|
||||
|
||||
// get the first non-optional field
|
||||
for(const auto& kv: sort_schema) {
|
||||
if(!kv.second.optional && kv.first != sort_field_const::text_match) {
|
||||
all_records_field = kv.first;
|
||||
break;
|
||||
size_t i = 0;
|
||||
for(const auto& kv: *kvs) {
|
||||
filter_ids[i++] = kv.first;
|
||||
}
|
||||
|
||||
// ids populated from hash map will not be sorted, but sorting is required for intersection & other ops
|
||||
std::sort(filter_ids, filter_ids+filter_ids_length);
|
||||
}
|
||||
|
||||
const spp::sparse_hash_map<uint32_t, int64_t> *kvs = sort_index.at(all_records_field);
|
||||
filter_ids_length = kvs->size();
|
||||
filter_ids = new uint32_t[filter_ids_length];
|
||||
|
||||
size_t i = 0;
|
||||
for(const auto& kv: *kvs) {
|
||||
filter_ids[i++] = kv.first;
|
||||
}
|
||||
|
||||
// ids populated from hash map will not be sorted, but sorting is required for intersection & other ops
|
||||
std::sort(filter_ids, filter_ids+filter_ids_length);
|
||||
}
|
||||
|
||||
if(!curated_ids.empty()) {
|
||||
@ -1919,7 +1912,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
const size_t group_limit, const std::vector<std::string>& group_by_fields,
|
||||
uint32_t token_bits) const {
|
||||
|
||||
std::vector<uint32_t*> leaf_to_indices;
|
||||
std::vector<uint32_t *> leaf_to_indices;
|
||||
for (art_leaf *token_leaf: query_suggestion) {
|
||||
uint32_t *indices = new uint32_t[result_size];
|
||||
token_leaf->values->ids.indexOf(result_ids, result_size, indices);
|
||||
@ -1937,19 +1930,25 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
|
||||
spp::sparse_hash_map<uint32_t, int64_t> geopoint_distances[3];
|
||||
|
||||
for(size_t i = 0; i < sort_fields.size(); i++) {
|
||||
spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value, seq_id_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> *TEXT_MATCH_SENTINEL = &text_match_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> *SEQ_ID_SENTINEL = &seq_id_sentinel_value;
|
||||
|
||||
for (size_t i = 0; i < sort_fields.size(); i++) {
|
||||
sort_order[i] = 1;
|
||||
if(sort_fields[i].order == sort_field_const::asc) {
|
||||
if (sort_fields[i].order == sort_field_const::asc) {
|
||||
sort_order[i] = -1;
|
||||
}
|
||||
|
||||
if(sort_fields[i].name == sort_field_const::text_match) {
|
||||
field_values[i] = nullptr;
|
||||
} else if(sort_schema.at(sort_fields[i].name).is_geopoint()) {
|
||||
if (sort_fields[i].name == sort_field_const::text_match) {
|
||||
field_values[i] = TEXT_MATCH_SENTINEL;
|
||||
} else if (sort_fields[i].name == sort_field_const::seq_id) {
|
||||
field_values[i] = SEQ_ID_SENTINEL;
|
||||
} else if (sort_schema.at(sort_fields[i].name).is_geopoint()) {
|
||||
// we have to populate distances that will be used for match scoring
|
||||
spp::sparse_hash_map<uint32_t, int64_t>* geopoints = sort_index.at(sort_fields[i].name);
|
||||
spp::sparse_hash_map<uint32_t, int64_t> *geopoints = sort_index.at(sort_fields[i].name);
|
||||
|
||||
for(size_t rindex=0; rindex<result_size; rindex++) {
|
||||
for (size_t rindex = 0; rindex < result_size; rindex++) {
|
||||
const uint32_t seq_id = result_ids[rindex];
|
||||
auto it = geopoints->find(seq_id);
|
||||
int64_t dist = (it == geopoints->end()) ? INT32_MAX : h3Distance(sort_fields[i].geopoint, it->second);
|
||||
@ -1964,23 +1963,23 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
|
||||
//auto begin = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(size_t i=0; i<result_size; i++) {
|
||||
for (size_t i = 0; i < result_size; i++) {
|
||||
const uint32_t seq_id = result_ids[i];
|
||||
|
||||
uint64_t match_score = 0;
|
||||
|
||||
if(query_suggestion.size() <= 1) {
|
||||
if (query_suggestion.size() <= 1) {
|
||||
match_score = single_token_match_score;
|
||||
} else {
|
||||
std::unordered_map<size_t, std::vector<std::vector<uint16_t>>> array_token_positions;
|
||||
populate_token_positions(query_suggestion, leaf_to_indices, i, array_token_positions);
|
||||
|
||||
for(const auto& kv: array_token_positions) {
|
||||
const std::vector<std::vector<uint16_t>>& token_positions = kv.second;
|
||||
if(token_positions.empty()) {
|
||||
for (const auto& kv: array_token_positions) {
|
||||
const std::vector<std::vector<uint16_t>> &token_positions = kv.second;
|
||||
if (token_positions.empty()) {
|
||||
continue;
|
||||
}
|
||||
const Match & match = Match(seq_id, token_positions, false);
|
||||
const Match &match = Match(seq_id, token_positions, false);
|
||||
uint64_t this_match_score = match.get_match_score(total_cost);
|
||||
|
||||
match_score += this_match_score;
|
||||
@ -2000,40 +1999,49 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
size_t match_score_index = 0;
|
||||
|
||||
// avoiding loop
|
||||
if(sort_fields.size() > 0) {
|
||||
if (field_values[0] != nullptr) {
|
||||
auto it = field_values[0]->find(seq_id);
|
||||
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
|
||||
} else {
|
||||
if (sort_fields.size() > 0) {
|
||||
if (field_values[0] == TEXT_MATCH_SENTINEL) {
|
||||
scores[0] = int64_t(match_score);
|
||||
match_score_index = 0;
|
||||
} else if (field_values[0] == SEQ_ID_SENTINEL) {
|
||||
scores[0] = seq_id;
|
||||
} else {
|
||||
auto it = field_values[0]->find(seq_id);
|
||||
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
|
||||
}
|
||||
if (sort_order[0] == -1) {
|
||||
scores[0] = -scores[0];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(sort_fields.size() > 1) {
|
||||
if (field_values[1] != nullptr) {
|
||||
auto it = field_values[1]->find(seq_id);
|
||||
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
|
||||
} else {
|
||||
if (field_values[1] == TEXT_MATCH_SENTINEL) {
|
||||
scores[1] = int64_t(match_score);
|
||||
match_score_index = 1;
|
||||
} else if (field_values[1] == SEQ_ID_SENTINEL) {
|
||||
scores[1] = seq_id;
|
||||
} else {
|
||||
auto it = field_values[1]->find(seq_id);
|
||||
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
|
||||
}
|
||||
|
||||
if (sort_order[1] == -1) {
|
||||
scores[1] = -scores[1];
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_fields.size() > 2) {
|
||||
if(field_values[2] != nullptr) {
|
||||
auto it = field_values[2]->find(seq_id);
|
||||
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
|
||||
} else {
|
||||
if(field_values[2] != TEXT_MATCH_SENTINEL) {
|
||||
scores[2] = int64_t(match_score);
|
||||
match_score_index = 2;
|
||||
} else if (field_values[2] == SEQ_ID_SENTINEL) {
|
||||
scores[2] = seq_id;
|
||||
} else {
|
||||
auto it = field_values[2]->find(seq_id);
|
||||
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
|
||||
}
|
||||
|
||||
if(sort_order[2] == -1) {
|
||||
scores[2] = -scores[2];
|
||||
}
|
||||
@ -2314,6 +2322,8 @@ Option<uint32_t> Index::remove(const uint32_t seq_id, const nlohmann::json & doc
|
||||
}
|
||||
}
|
||||
|
||||
seq_ids.remove_value(seq_id);
|
||||
|
||||
return Option<uint32_t>(seq_id);
|
||||
}
|
||||
|
||||
|
@ -780,10 +780,6 @@ TEST_F(CollectionFacetingTest, FacetCountOnSimilarStrings) {
|
||||
token_ordering::FREQUENCY, true, 10, spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10).get();
|
||||
|
||||
LOG(INFO) << results;
|
||||
|
||||
return;
|
||||
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
ASSERT_EQ(2, results["facet_counts"][0]["counts"].size());
|
||||
|
||||
|
@ -304,6 +304,13 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) {
|
||||
ASSERT_EQ(1, coll1->get_collection_id());
|
||||
ASSERT_EQ(3, coll1->get_sort_fields().size());
|
||||
|
||||
// index a document with a bad field value with COERCE_OR_IGNORE setting
|
||||
auto doc_json = R"({"title": "Unique record.", "max": 25, "scores": [22, "how", 44],
|
||||
"average": "bad data", "is_valid": true})";
|
||||
|
||||
Option<nlohmann::json> add_op = coll1->add(doc_json, CREATE, "", DIRTY_VALUES::COERCE_OR_IGNORE);
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
std::unordered_map<std::string, field> schema = collection1->get_schema();
|
||||
|
||||
// create a new collection manager to ensure that it restores the records from the disk backed store
|
||||
@ -324,7 +331,8 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) {
|
||||
auto restored_schema = restored_coll->get_schema();
|
||||
|
||||
ASSERT_EQ(1, restored_coll->get_collection_id());
|
||||
ASSERT_EQ(6, restored_coll->get_next_seq_id());
|
||||
ASSERT_EQ(7, restored_coll->get_next_seq_id());
|
||||
ASSERT_EQ(7, restored_coll->get_num_documents());
|
||||
ASSERT_EQ(facet_fields_expected, restored_coll->get_facet_fields());
|
||||
ASSERT_EQ(3, restored_coll->get_sort_fields().size());
|
||||
ASSERT_EQ("is_valid", restored_coll->get_sort_fields()[0].name);
|
||||
@ -347,6 +355,24 @@ TEST_F(CollectionManagerTest, RestoreAutoSchemaDocsOnRestart) {
|
||||
ASSERT_FALSE(kv.second.optional);
|
||||
}
|
||||
|
||||
// try searching for record with bad data
|
||||
auto results = restored_coll->search("unique", {"title"}, "", {}, {}, 0, 10, 1, FREQUENCY, false).get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_STREQ("Unique record.", results["hits"][0]["document"]["title"].get<std::string>().c_str());
|
||||
ASSERT_EQ(0, results["hits"][0]["document"].count("average"));
|
||||
ASSERT_EQ(2, results["hits"][0]["document"]["scores"].size());
|
||||
ASSERT_EQ(22, results["hits"][0]["document"]["scores"][0]);
|
||||
ASSERT_EQ(44, results["hits"][0]["document"]["scores"][1]);
|
||||
|
||||
// try sorting on `average`, a field that not all records have
|
||||
ASSERT_EQ(7, restored_coll->get_num_documents());
|
||||
|
||||
sort_fields = { sort_by("average", "DESC") };
|
||||
results = restored_coll->search("*", {"title"}, "", {}, {sort_fields}, 0, 10, 1, FREQUENCY, false).get();
|
||||
|
||||
ASSERT_EQ(7, results["hits"].size());
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
collectionManager2.drop_collection("coll1");
|
||||
}
|
||||
|
@ -129,19 +129,114 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
|
||||
|
||||
Option<Collection*> collection_op = collectionManager.create_collection("sample_collection", 4, fields, "name");
|
||||
EXPECT_FALSE(collection_op.ok());
|
||||
EXPECT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error());
|
||||
ASSERT_FALSE(collection_op.ok());
|
||||
ASSERT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error());
|
||||
collectionManager.drop_collection("sample_collection");
|
||||
|
||||
// Default sorting field must exist as a field in schema
|
||||
|
||||
sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
|
||||
collection_op = collectionManager.create_collection("sample_collection", 4, fields, "NOT-DEFINED");
|
||||
EXPECT_FALSE(collection_op.ok());
|
||||
EXPECT_EQ("Default sorting field is defined as `NOT-DEFINED` but is not found in the schema.", collection_op.error());
|
||||
ASSERT_FALSE(collection_op.ok());
|
||||
ASSERT_EQ("Default sorting field is defined as `NOT-DEFINED` but is not found in the schema.", collection_op.error());
|
||||
collectionManager.drop_collection("sample_collection");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, NoDefaultSortingField) {
    // Verifies that a collection created without a default sorting field:
    //  1) orders text matches by (text_match, seq_id), and
    //  2) keeps wildcard searches consistent after a document is removed
    //     (exercises the seq_id array used for wildcard searches).
    std::vector<field> collection_fields = {field("title", field_types::STRING, false),
                                            field("points", field_types::INT32, false)};

    Collection* coll1 = collectionManager.get_collection("coll1").get();
    if(coll1 == nullptr) {
        // NOTE: no default_sorting_field argument is passed here on purpose
        coll1 = collectionManager.create_collection("coll1", 4, collection_fields).get();
    }

    // index the documents from the shared JSONL fixture
    std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl");
    std::string json_line;
    while (std::getline(infile, json_line)) {
        coll1->add(json_line);
    }
    infile.close();

    // without a default sorting field, matches should be sorted by (text_match, seq_id)
    auto results = coll1->search("rocket", {"title"}, "", {}, {}, 1, 10, 1, FREQUENCY, false).get();

    ASSERT_EQ(5, results["found"]);
    ASSERT_EQ(5, results["hits"].size());
    ASSERT_EQ(24, results["out_of"]);

    const std::vector<std::string> expected_ids = {"16", "15", "7", "0", "22"};
    size_t pos = 0;
    for(const auto& hit : results["hits"]) {
        ASSERT_EQ(expected_ids[pos], hit["document"]["id"].get<std::string>());
        pos++;
    }

    // try removing a document and doing wildcard (tests the seq_id array used for wildcard searches)
    auto remove_op = coll1->remove("0");
    ASSERT_TRUE(remove_op.ok());

    results = coll1->search("*", {}, "", {}, {}, 1, 30, 1, FREQUENCY, false).get();

    ASSERT_EQ(23, results["found"]);
    ASSERT_EQ(23, results["hits"].size());
    ASSERT_EQ(23, results["out_of"]);

    // wildcard results arrive in descending seq_id order; the fixture's
    // fifth document (seq_id 4) carries the literal id "foo"
    for(size_t i = 23; i >= 1; i--) {
        const std::string expected_doc_id = (i == 4) ? "foo" : std::to_string(i);
        ASSERT_EQ(expected_doc_id, results["hits"][23 - i]["document"]["id"].get<std::string>());
    }
}
|
||||
|
||||
TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField) {
    // when no default sorting field is provided, tokens must be ordered on frequency
    std::vector<field> collection_fields = {field("title", field_types::STRING, false),
                                            field("points", field_types::INT32, false)};

    Collection* coll1 = collectionManager.get_collection("coll1").get();
    if(coll1 == nullptr) {
        coll1 = collectionManager.create_collection("coll1", 1, collection_fields).get();
    }

    // since only top 10 tokens are fetched for prefixes, the "end" should not show up in the results
    const std::vector<std::string> tokens = {
        "enter", "elephant", "enamel", "ercot", "enyzme", "energy",
        "epoch", "epyc", "express", "everest", "end"
    };

    // token at index i is indexed (tokens.size() - i) times, so earlier
    // tokens end up strictly more frequent than later ones
    for(size_t i = 0; i < tokens.size(); i++) {
        const size_t num_repeat = tokens.size() - i;

        for(size_t j = 0; j < num_repeat; j++) {
            nlohmann::json doc;
            doc["title"] = tokens[i];
            doc["points"] = num_repeat;
            coll1->add(doc.dump());
        }
    }

    auto results = coll1->search("e", {"title"}, "", {}, {}, 0, 100, 1, NOT_SET, true).get();

    // 11 + 10 + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2
    ASSERT_EQ(65, results["found"]);

    // we have to ensure that no result contains the word "end" since it occurs least number of times
    bool found_end = false;
    for(auto& res: results["hits"].items()) {
        if(res.value()["document"]["title"] == "end") {
            found_end = true;
        }
    }

    ASSERT_FALSE(found_end);
}
|
||||
|
||||
TEST_F(CollectionSortingTest, Int64AsDefaultSortingField) {
|
||||
Collection *coll_mul_fields;
|
||||
|
||||
|
@ -480,8 +480,6 @@ TEST_F(CollectionTest, WildcardQuery) {
|
||||
nlohmann::json results = collection->search("*", query_fields, "points:>0", {}, sort_fields, 0, 3, 1, FREQUENCY,
|
||||
false).get();
|
||||
|
||||
LOG(INFO) << results;
|
||||
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
ASSERT_EQ(25, results["found"].get<uint32_t>());
|
||||
|
||||
@ -1639,10 +1637,10 @@ TEST_F(CollectionTest, IndexingWithBadData) {
|
||||
const Option<nlohmann::json> & empty_facet_field_op = sample_collection->add(doc_str);
|
||||
ASSERT_TRUE(empty_facet_field_op.ok());
|
||||
|
||||
doc_str = "{\"name\": \"foo\", \"age\": \"34\", \"tags\": [], \"average\": 34 }";
|
||||
doc_str = "{\"name\": \"foo\", \"age\": [\"34\"], \"tags\": [], \"average\": 34 }";
|
||||
const Option<nlohmann::json> & bad_default_sorting_field_op1 = sample_collection->add(doc_str);
|
||||
ASSERT_FALSE(bad_default_sorting_field_op1.ok());
|
||||
ASSERT_STREQ("Default sorting field `age` must be a single valued numerical field.", bad_default_sorting_field_op1.error().c_str());
|
||||
ASSERT_STREQ("Field `age` must be an int32.", bad_default_sorting_field_op1.error().c_str());
|
||||
|
||||
doc_str = "{\"name\": \"foo\", \"tags\": [], \"average\": 34 }";
|
||||
const Option<nlohmann::json> & bad_default_sorting_field_op3 = sample_collection->add(doc_str);
|
||||
@ -2691,8 +2689,6 @@ TEST_F(CollectionTest, MultiFieldRelevance) {
|
||||
auto results = coll1->search("Dustin Kensrue Down There by the Train",
|
||||
{"title", "artist"}, "", {}, {}, 0, 10, 1, FREQUENCY).get();
|
||||
|
||||
LOG(INFO) << results;
|
||||
|
||||
ASSERT_EQ(3, results["found"].get<size_t>());
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user