mirror of
https://github.com/typesense/typesense.git
synced 2025-05-16 19:55:21 +08:00
Allow results to be sorted on a float field.
This commit is contained in:
parent
3104dea42a
commit
d351523655
@ -3,6 +3,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <art.h>
|
||||
#include <number.h>
|
||||
#include <sparsepp.h>
|
||||
#include <store.h>
|
||||
#include <topster.h>
|
||||
@ -52,7 +53,7 @@ private:
|
||||
|
||||
spp::sparse_hash_map<std::string, field> facet_schema;
|
||||
|
||||
std::vector<field> sort_fields;
|
||||
spp::sparse_hash_map<std::string, field> sort_schema;
|
||||
|
||||
Store* store;
|
||||
|
||||
@ -60,7 +61,7 @@ private:
|
||||
|
||||
spp::sparse_hash_map<std::string, facet_value> facet_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, int64_t>*> sort_index;
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, number_t>*> sort_index;
|
||||
|
||||
std::string token_ranking_field;
|
||||
|
||||
@ -84,14 +85,14 @@ private:
|
||||
size_t result_index, std::vector<std::vector<uint16_t>> &token_positions) const;
|
||||
|
||||
void search_field(std::string & query, const std::string & field, uint32_t *filter_ids, size_t filter_ids_length,
|
||||
std::vector<facet> & facets, const std::vector<sort_field> & sort_fields,
|
||||
std::vector<facet> & facets, const std::vector<sort_by> & sort_fields,
|
||||
const int num_typos, const size_t num_results,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries, int & searched_queries_index,
|
||||
Topster<100> & topster, uint32_t** all_result_ids,
|
||||
size_t & all_result_ids_len, const token_ordering token_order = FREQUENCY, const bool prefix = false);
|
||||
|
||||
void search_candidates(uint32_t* filter_ids, size_t filter_ids_length, std::vector<facet> & facets,
|
||||
const std::vector<sort_field> & sort_fields, int & candidate_rank,
|
||||
const std::vector<sort_by> & sort_fields, int & candidate_rank,
|
||||
std::vector<std::vector<art_leaf*>> & token_to_candidates,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries, Topster<100> & topster,
|
||||
size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len,
|
||||
@ -155,7 +156,7 @@ public:
|
||||
|
||||
Option<nlohmann::json> search(std::string query, const std::vector<std::string> search_fields,
|
||||
const std::string & simple_filter_query, const std::vector<std::string> & facet_fields,
|
||||
const std::vector<sort_field> & sort_fields, const int num_typos,
|
||||
const std::vector<sort_by> & sort_fields, const int num_typos,
|
||||
const size_t per_page = 10, const size_t page = 1,
|
||||
const token_ordering token_order = FREQUENCY, const bool prefix = false);
|
||||
|
||||
@ -163,7 +164,7 @@ public:
|
||||
|
||||
Option<std::string> remove(const std::string & id);
|
||||
|
||||
void score_results(const std::vector<sort_field> & sort_fields, const int & query_index, const int & candidate_rank,
|
||||
void score_results(const std::vector<sort_by> & sort_fields, const int & query_index, const int & candidate_rank,
|
||||
Topster<100> &topster, const std::vector<art_leaf *> & query_suggestion, const uint32_t *result_ids,
|
||||
const size_t result_size) const;
|
||||
|
||||
|
@ -81,15 +81,15 @@ namespace sort_field_const {
|
||||
static const std::string desc = "DESC";
|
||||
}
|
||||
|
||||
struct sort_field {
|
||||
struct sort_by {
|
||||
std::string name;
|
||||
std::string order;
|
||||
|
||||
sort_field(const std::string & name, const std::string & order): name(name), order(order) {
|
||||
sort_by(const std::string & name, const std::string & order): name(name), order(order) {
|
||||
|
||||
}
|
||||
|
||||
sort_field& operator=(sort_field other) {
|
||||
sort_by& operator=(sort_by other) {
|
||||
name = other.name;
|
||||
order = other.order;
|
||||
return *this;
|
||||
|
80
include/number.h
Normal file
80
include/number.h
Normal file
@ -0,0 +1,80 @@
|
||||
#pragma once
|
||||
|
||||
#include <sparsepp.h>
|
||||
|
||||
/*
 * A tagged numeric value that holds either a 32-bit float or a 64-bit signed integer.
 *
 * `is_float` records which union member is active: `floatval` when true, `intval` otherwise.
 * Binary operators use the native type when both operands carry the same tag. When the tags
 * differ, both operands are widened to double instead of reading the inactive union member
 * of the right-hand side (which would reinterpret float bits as an integer or vice versa).
 */
struct number_t {
    bool is_float;
    union {
        float floatval;
        int64_t intval;
    };

    // Defaults to the integer zero.
    // NOTE: initializers are listed in declaration order (tag first, then payload).
    number_t(): is_float(false), intval(0) {

    }

    // Explicit-tag constructors: the caller is responsible for passing a tag that matches the payload.
    number_t(bool is_float, float floatval): is_float(is_float), floatval(floatval) {

    }

    number_t(bool is_float, int64_t intval): is_float(is_float), intval(intval) {

    }

    // Converting constructors: the tag is derived from the argument type.
    number_t(float val): is_float(true), floatval(val) {

    }

    number_t(int64_t val): is_float(false), intval(val) {

    }

    // Assigning a raw float/int64 replaces both the payload and the tag.
    inline void operator = (const float & val) {
        floatval = val;
        is_float = true;
    }

    inline void operator = (const int64_t & val) {
        intval = val;
        is_float = false;
    }

    // Value of the active member widened to double; only used for mixed-tag operations.
    inline double to_double() const {
        return is_float ? static_cast<double>(floatval) : static_cast<double>(intval);
    }

    inline bool operator == (const number_t & rhs) const {
        if(is_float != rhs.is_float) {
            return to_double() == rhs.to_double();
        }

        if(is_float) {
            return floatval == rhs.floatval;
        }

        return intval == rhs.intval;
    }

    inline bool operator < (const number_t & rhs) const {
        if(is_float != rhs.is_float) {
            return to_double() < rhs.to_double();
        }

        if(is_float) {
            return floatval < rhs.floatval;
        }

        return intval < rhs.intval;
    }

    inline bool operator > (const number_t & rhs) const {
        if(is_float != rhs.is_float) {
            return to_double() > rhs.to_double();
        }

        if(is_float) {
            return floatval > rhs.floatval;
        }

        return intval > rhs.intval;
    }

    // Product of two values. Same-tag operands keep their type; a mixed-tag product is a float.
    inline number_t operator * (const number_t & rhs) const {
        if(is_float != rhs.is_float) {
            return number_t(static_cast<float>(to_double() * rhs.to_double()));
        }

        if(is_float) {
            return number_t((float)(floatval * rhs.floatval));
        }

        return number_t((int64_t)(intval * rhs.intval));
    }

    // Returns the negated value as a new object. The operand is left untouched, so `-x` no longer
    // changes `x` as a side effect and the operator can be applied to const values.
    inline number_t operator-() const {
        if(is_float) {
            return number_t(-floatval);
        }

        return number_t(-intval);
    }
};
|
85
include/person.h
Normal file
85
include/person.h
Normal file
@ -0,0 +1,85 @@
|
||||
#pragma once
|
||||
|
||||
/*
 * Tagged numeric value: `is_float` selects which member of the anonymous union is meaningful
 * (`floatval` when true, `intval` when false). Comparison and multiplication are dispatched on
 * the tag of the left-hand operand only.
 */
struct person {
    bool is_float;
    union {
        int64_t intval;
        float floatval;
    };

    // Integer zero by default.
    person(): is_float(false), intval(0) {}

    // Constructors taking an explicit tag alongside the payload.
    person(bool is_float, float floatval): is_float(is_float), floatval(floatval) {}

    person(bool is_float, int64_t intval): is_float(is_float), intval(intval) {}

    // Constructors deriving the tag from the argument type.
    person(float val): is_float(true), floatval(val) {}

    person(int64_t val): is_float(false), intval(val) {}

    // Raw assignment updates the tag together with the payload.
    void operator = (const float & val) {
        is_float = true;
        floatval = val;
    }

    void operator = (const int64_t & val) {
        is_float = false;
        intval = val;
    }

    bool operator == (const person & rhs) const {
        return is_float ? (floatval == rhs.floatval) : (intval == rhs.intval);
    }

    bool operator < (const person & rhs) const {
        return is_float ? (floatval < rhs.floatval) : (intval < rhs.intval);
    }

    bool operator > (const person & rhs) const {
        return is_float ? (floatval > rhs.floatval) : (intval > rhs.intval);
    }

    // The product keeps the type indicated by the left operand's tag.
    person operator * (const person & rhs) const {
        if(!is_float) {
            return person(intval * rhs.intval);
        }

        return person(floatval * rhs.floatval);
    }
};
|
||||
|
||||
namespace std
|
||||
{
|
||||
// inject specialization of std::hash for Person into namespace std
|
||||
// ----------------------------------------------------------------
|
||||
template<>
|
||||
struct hash<person>
|
||||
{
|
||||
std::size_t operator()(person const &p) const
|
||||
{
|
||||
std::size_t seed = 0;
|
||||
spp::hash_combine(seed, p.is_float);
|
||||
spp::hash_combine(seed, p.intval);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
}
|
@ -6,56 +6,13 @@
|
||||
#include <algorithm>
|
||||
#include <sparsepp.h>
|
||||
#include <match_score.h>
|
||||
#include <number.h>
|
||||
|
||||
/*
|
||||
* Remembers the max-K elements seen so far using a min-heap
|
||||
*/
|
||||
template <size_t MAX_SIZE=100>
|
||||
struct Topster {
|
||||
struct number_t {
|
||||
bool is_float;
|
||||
union {
|
||||
float floatval;
|
||||
int64_t intval;
|
||||
};
|
||||
|
||||
number_t(): intval(0), is_float(false) {
|
||||
|
||||
}
|
||||
|
||||
number_t(float val): floatval(val), is_float(true) {
|
||||
|
||||
}
|
||||
|
||||
number_t(int64_t val): intval(val), is_float(false) {
|
||||
|
||||
}
|
||||
|
||||
inline void operator = (const float & val) {
|
||||
floatval = val;
|
||||
is_float = true;
|
||||
}
|
||||
|
||||
inline void operator = (const int64_t & val) {
|
||||
intval = val;
|
||||
is_float = false;
|
||||
}
|
||||
|
||||
inline bool operator < (const number_t & rhs) const {
|
||||
if(is_float) {
|
||||
return floatval < rhs.floatval;
|
||||
}
|
||||
return intval < rhs.intval;
|
||||
}
|
||||
|
||||
inline bool operator > (const number_t & rhs) const {
|
||||
if(is_float) {
|
||||
return floatval > rhs.floatval;
|
||||
}
|
||||
return intval > rhs.intval;
|
||||
}
|
||||
};
|
||||
|
||||
struct KV {
|
||||
uint16_t query_index;
|
||||
uint64_t key;
|
||||
|
@ -188,7 +188,7 @@ void get_search(http_req & req, http_res & res) {
|
||||
std::vector<std::string> facet_fields;
|
||||
StringUtils::split(req.params[FACET_BY], facet_fields, "&&");
|
||||
|
||||
std::vector<sort_field> sort_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
if(req.params.count(SORT_BY) != 0) {
|
||||
std::vector<std::string> sort_field_strs;
|
||||
StringUtils::split(req.params[SORT_BY], sort_field_strs, ",");
|
||||
@ -206,7 +206,7 @@ void get_search(http_req & req, http_res & res) {
|
||||
}
|
||||
|
||||
StringUtils::toupper(expression_parts[1]);
|
||||
sort_fields.push_back(sort_field(expression_parts[0], expression_parts[1]));
|
||||
sort_fields.push_back(sort_by(expression_parts[0], expression_parts[1]));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ Collection::Collection(const std::string name, const uint32_t collection_id, con
|
||||
const std::vector<field> &search_fields, const std::vector<field> & facet_fields,
|
||||
const std::vector<field> & sort_fields, const std::string token_ranking_field):
|
||||
name(name), collection_id(collection_id), next_seq_id(next_seq_id), store(store),
|
||||
sort_fields(sort_fields), token_ranking_field(token_ranking_field) {
|
||||
token_ranking_field(token_ranking_field) {
|
||||
|
||||
for(const field& field: search_fields) {
|
||||
art_tree *t = new art_tree;
|
||||
@ -27,8 +27,9 @@ Collection::Collection(const std::string name, const uint32_t collection_id, con
|
||||
}
|
||||
|
||||
for(const field & sort_field: sort_fields) {
|
||||
spp::sparse_hash_map<uint32_t, int64_t> * doc_to_score = new spp::sparse_hash_map<uint32_t, int64_t>();
|
||||
spp::sparse_hash_map<uint32_t, number_t> * doc_to_score = new spp::sparse_hash_map<uint32_t, number_t>();
|
||||
sort_index.emplace(sort_field.name, doc_to_score);
|
||||
sort_schema.emplace(sort_field.name, sort_field);
|
||||
}
|
||||
|
||||
num_documents = 0;
|
||||
@ -218,7 +219,9 @@ Option<uint32_t> Collection::index_in_memory(const nlohmann::json &document, uin
|
||||
}
|
||||
}
|
||||
|
||||
for(const field & sort_field: sort_fields) {
|
||||
for(const std::pair<std::string, field> & field_pair: sort_schema) {
|
||||
const field & sort_field = field_pair.second;
|
||||
|
||||
if(document.count(sort_field.name) == 0) {
|
||||
return Option<>(400, "Field `" + sort_field.name + "` has been declared as a sort field in the schema, "
|
||||
"but is not found in the document.");
|
||||
@ -228,8 +231,13 @@ Option<uint32_t> Collection::index_in_memory(const nlohmann::json &document, uin
|
||||
return Option<>(400, "Sort field `" + sort_field.name + "` must be a number.");
|
||||
}
|
||||
|
||||
spp::sparse_hash_map<uint32_t, int64_t> *doc_to_score = sort_index.at(sort_field.name);
|
||||
doc_to_score->emplace(seq_id, document[sort_field.name].get<int64_t>());
|
||||
spp::sparse_hash_map<uint32_t, number_t> *doc_to_score = sort_index.at(sort_field.name);
|
||||
|
||||
if(document[sort_field.name].is_number_integer()) {
|
||||
doc_to_score->emplace(seq_id, document[sort_field.name].get<int64_t>());
|
||||
} else {
|
||||
doc_to_score->emplace(seq_id, document[sort_field.name].get<float>());
|
||||
}
|
||||
}
|
||||
|
||||
num_documents += 1;
|
||||
@ -401,7 +409,7 @@ void Collection::do_facets(std::vector<facet> & facets, uint32_t* result_ids, si
|
||||
}
|
||||
|
||||
void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_length, std::vector<facet> & facets,
|
||||
const std::vector<sort_field> & sort_fields, int & candidate_rank,
|
||||
const std::vector<sort_by> & sort_fields, int & candidate_rank,
|
||||
std::vector<std::vector<art_leaf*>> & token_to_candidates,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries, Topster<100> & topster,
|
||||
size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len,
|
||||
@ -649,7 +657,7 @@ Option<uint32_t> Collection::do_filtering(uint32_t** filter_ids_out, const std::
|
||||
|
||||
Option<nlohmann::json> Collection::search(std::string query, const std::vector<std::string> search_fields,
|
||||
const std::string & simple_filter_query, const std::vector<std::string> & facet_fields,
|
||||
const std::vector<sort_field> & sort_fields, const int num_typos,
|
||||
const std::vector<sort_by> & sort_fields, const int num_typos,
|
||||
const size_t per_page, const size_t page,
|
||||
const token_ordering token_order, const bool prefix) {
|
||||
nlohmann::json result = nlohmann::json::object();
|
||||
@ -680,9 +688,9 @@ Option<nlohmann::json> Collection::search(std::string query, const std::vector<s
|
||||
|
||||
// validate sort fields and standardize
|
||||
|
||||
std::vector<sort_field> sort_fields_std;
|
||||
std::vector<sort_by> sort_fields_std;
|
||||
|
||||
for(const sort_field & _sort_field: sort_fields) {
|
||||
for(const sort_by & _sort_field: sort_fields) {
|
||||
if(sort_index.count(_sort_field.name) == 0) {
|
||||
std::string error = "Could not find a sort field named `" + _sort_field.name + "` in the schema.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
@ -888,7 +896,7 @@ Option<nlohmann::json> Collection::search(std::string query, const std::vector<s
|
||||
5. Sort the docs based on some ranking criteria
|
||||
*/
|
||||
void Collection::search_field(std::string & query, const std::string & field, uint32_t *filter_ids, size_t filter_ids_length,
|
||||
std::vector<facet> & facets, const std::vector<sort_field> & sort_fields, const int num_typos,
|
||||
std::vector<facet> & facets, const std::vector<sort_by> & sort_fields, const int num_typos,
|
||||
const size_t num_results, std::vector<std::vector<art_leaf*>> & searched_queries,
|
||||
int & searched_queries_index, Topster<100> &topster, uint32_t** all_result_ids, size_t & all_result_ids_len,
|
||||
const token_ordering token_order, const bool prefix) {
|
||||
@ -1048,7 +1056,7 @@ void Collection::log_leaves(const int cost, const std::string &token, const std:
|
||||
}
|
||||
}
|
||||
|
||||
void Collection::score_results(const std::vector<sort_field> & sort_fields, const int & query_index, const int & candidate_rank,
|
||||
void Collection::score_results(const std::vector<sort_by> & sort_fields, const int & query_index, const int & candidate_rank,
|
||||
Topster<100> & topster, const std::vector<art_leaf *> &query_suggestion,
|
||||
const uint32_t *result_ids, const size_t result_size) const {
|
||||
|
||||
@ -1061,25 +1069,43 @@ void Collection::score_results(const std::vector<sort_field> & sort_fields, cons
|
||||
leaf_to_indices.emplace(token_leaf, indices);
|
||||
}
|
||||
|
||||
spp::sparse_hash_map<uint32_t, int64_t> * primary_rank_scores = nullptr;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> * secondary_rank_scores = nullptr;
|
||||
spp::sparse_hash_map<uint32_t, number_t> * primary_rank_scores = nullptr;
|
||||
spp::sparse_hash_map<uint32_t, number_t> * secondary_rank_scores = nullptr;
|
||||
|
||||
// Used for asc/desc ordering. NOTE: Topster keeps biggest keys (i.e. it's desc in nature)
|
||||
int64_t primary_rank_factor = 1;
|
||||
int64_t secondary_rank_factor = 1;
|
||||
number_t primary_rank_factor;
|
||||
number_t secondary_rank_factor;
|
||||
|
||||
if(sort_fields.size() > 0) {
|
||||
// assumed that rank field exists in the index - checked earlier in the chain
|
||||
primary_rank_scores = sort_index.at(sort_fields[0].name);
|
||||
|
||||
// initialize primary_rank_factor
|
||||
field sort_field = sort_schema.at(sort_fields[0].name);
|
||||
if(sort_field.is_integer()) {
|
||||
primary_rank_factor = ((int64_t) 1);
|
||||
} else {
|
||||
primary_rank_factor = ((float) 1);
|
||||
}
|
||||
|
||||
if(sort_fields[0].order == sort_field_const::asc) {
|
||||
primary_rank_factor = -1;
|
||||
primary_rank_factor = -primary_rank_factor;
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_fields.size() > 1) {
|
||||
secondary_rank_scores = sort_index.at(sort_fields[1].name);
|
||||
|
||||
// initialize secondary_rank_factor
|
||||
field sort_field = sort_schema.at(sort_fields[1].name);
|
||||
if(sort_field.is_integer()) {
|
||||
secondary_rank_factor = ((int64_t) 1);
|
||||
} else {
|
||||
secondary_rank_factor = ((float) 1);
|
||||
}
|
||||
|
||||
if(sort_fields[1].order == sort_field_const::asc) {
|
||||
secondary_rank_factor = -1;
|
||||
secondary_rank_factor = -secondary_rank_factor;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1105,13 +1131,15 @@ void Collection::score_results(const std::vector<sort_field> & sort_fields, cons
|
||||
(candidate_rank_score << 8) +
|
||||
(MAX_SEARCH_TOKENS - mscore.distance);
|
||||
|
||||
int64_t primary_rank_score = (primary_rank_scores && primary_rank_scores->count(seq_id) > 0) ?
|
||||
primary_rank_scores->at(seq_id) : 0;
|
||||
int64_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ?
|
||||
secondary_rank_scores->at(seq_id) : 0;
|
||||
const int64_t default_score = 0;
|
||||
number_t primary_rank_score = (primary_rank_scores && primary_rank_scores->count(seq_id) > 0) ?
|
||||
primary_rank_scores->at(seq_id) : default_score;
|
||||
number_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ?
|
||||
secondary_rank_scores->at(seq_id) : default_score;
|
||||
|
||||
topster.add(seq_id, query_index, match_score, primary_rank_factor * primary_rank_score,
|
||||
secondary_rank_factor * secondary_rank_score);
|
||||
const number_t & primary_rank_value = primary_rank_score * primary_rank_factor;
|
||||
const number_t & secondary_rank_value = secondary_rank_score * secondary_rank_factor;
|
||||
topster.add(seq_id, query_index, match_score, primary_rank_value, secondary_rank_value);
|
||||
|
||||
/*std::cout << "candidate_rank_score: " << candidate_rank_score << ", words_present: " << mscore.words_present
|
||||
<< ", match_score: " << match_score << ", primary_rank_score: " << primary_rank_score
|
||||
@ -1399,7 +1427,12 @@ std::vector<std::string> Collection::get_facet_fields() {
|
||||
}
|
||||
|
||||
std::vector<field> Collection::get_sort_fields() {
|
||||
return sort_fields;
|
||||
std::vector<field> sort_fields_copy;
|
||||
for(auto it = sort_schema.begin(); it != sort_schema.end(); ++it) {
|
||||
sort_fields_copy.push_back(it->second);
|
||||
}
|
||||
|
||||
return sort_fields_copy;
|
||||
}
|
||||
|
||||
spp::sparse_hash_map<std::string, field> Collection::get_schema() {
|
||||
|
@ -14,7 +14,7 @@ protected:
|
||||
std::vector<field> facet_fields;
|
||||
std::vector<field> sort_fields_index;
|
||||
|
||||
std::vector<sort_field> sort_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
void setupCollection() {
|
||||
std::string state_dir_path = "/tmp/typesense_test/coll_manager_test_db";
|
||||
@ -26,7 +26,7 @@ protected:
|
||||
|
||||
search_fields = {field("title", field_types::STRING), field("starring", field_types::STRING)};
|
||||
facet_fields = {field("starring", field_types::STRING)};
|
||||
sort_fields = { sort_field("points", "DESC") };
|
||||
sort_fields = { sort_by("points", "DESC") };
|
||||
sort_fields_index = { field("points", "INT32") };
|
||||
|
||||
collection1 = collectionManager.create_collection("collection1", search_fields, facet_fields,
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include <algorithm>
|
||||
#include <collection_manager.h>
|
||||
#include "collection.h"
|
||||
#include "person.h"
|
||||
#include "number.h"
|
||||
|
||||
class CollectionTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -14,7 +16,7 @@ protected:
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
std::vector<field> facet_fields;
|
||||
std::vector<field> sort_fields_index;
|
||||
std::vector<sort_field> sort_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
void setupCollection() {
|
||||
std::string state_dir_path = "/tmp/typesense_test/collection";
|
||||
@ -29,7 +31,7 @@ protected:
|
||||
|
||||
query_fields = {"title"};
|
||||
facet_fields = { };
|
||||
sort_fields = { sort_field("points", "DESC") };
|
||||
sort_fields = { sort_by("points", "DESC") };
|
||||
sort_fields_index = { field("points", "INT32") };
|
||||
|
||||
collection = collectionManager.get_collection("collection");
|
||||
@ -94,7 +96,7 @@ TEST_F(CollectionTest, ExactSearchShouldBeStable) {
|
||||
}
|
||||
|
||||
// check ASC sorting
|
||||
std::vector<sort_field> sort_fields_asc = { sort_field("points", "ASC") };
|
||||
std::vector<sort_by> sort_fields_asc = { sort_by("points", "ASC") };
|
||||
|
||||
results = collection->search("the", query_fields, "", facets, sort_fields_asc, 0, 10).get();
|
||||
ASSERT_EQ(7, results["hits"].size());
|
||||
@ -135,7 +137,7 @@ TEST_F(CollectionTest, ExactPhraseSearch) {
|
||||
}
|
||||
|
||||
// Check ASC sort order
|
||||
std::vector<sort_field> sort_fields_asc = { sort_field("points", "ASC") };
|
||||
std::vector<sort_by> sort_fields_asc = { sort_by("points", "ASC") };
|
||||
results = collection->search("rocket launch", query_fields, "", facets, sort_fields_asc, 0, 10).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
ASSERT_EQ(5, results["found"].get<uint32_t>());
|
||||
@ -495,7 +497,7 @@ TEST_F(CollectionTest, FilterOnNumericFields) {
|
||||
std::vector<field> fields = {field("name", field_types::STRING), field("age", field_types::INT32),
|
||||
field("years", field_types::INT32_ARRAY),
|
||||
field("timestamps", field_types::INT64_ARRAY)};
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
std::vector<field> sort_fields_index = { field("age", "INT32") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields");
|
||||
@ -632,8 +634,8 @@ TEST_F(CollectionTest, FilterOnFloatFields) {
|
||||
field("top_3", field_types::FLOAT_ARRAY),
|
||||
field("rating", field_types::FLOAT)};
|
||||
std::vector<field> sort_fields_index = { field("rating", "FLOAT") };
|
||||
std::vector<sort_field> sort_fields_desc = { sort_field("rating", "DESC") };
|
||||
std::vector<sort_field> sort_fields_asc = { sort_field("rating", "ASC") };
|
||||
std::vector<sort_by> sort_fields_desc = { sort_by("rating", "DESC") };
|
||||
std::vector<sort_by> sort_fields_asc = { sort_by("rating", "ASC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields");
|
||||
if(coll_array_fields == nullptr) {
|
||||
@ -673,7 +675,7 @@ TEST_F(CollectionTest, FilterOnFloatFields) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str()); //?
|
||||
}
|
||||
|
||||
// Searching on a float field, sorted desc by rating
|
||||
@ -761,6 +763,72 @@ TEST_F(CollectionTest, FilterOnFloatFields) {
|
||||
collectionManager.drop_collection("coll_array_fields");
|
||||
}
|
||||
|
||||
// Indexes the documents from test/float_documents.jsonl into a collection whose sort fields
// (`score`, `average`) are declared as FLOAT, then checks the order of the returned hit ids for
// three sort specifications: both fields DESC, both fields ASC, and ASC on `score` with DESC on
// `average`.
TEST_F(CollectionTest, SortOnFloatFields) {
    std::ifstream infile(std::string(ROOT_DIR)+"test/float_documents.jsonl");

    std::vector<field> fields = {field("title", field_types::STRING), field("score", field_types::FLOAT)};
    std::vector<field> sort_fields_index = { field("score", "FLOAT"), field("average", "FLOAT") };

    // Reuse the collection if it already exists, otherwise create it.
    Collection *coll_float_fields = collectionManager.get_collection("coll_float_fields");
    if(coll_float_fields == nullptr) {
        coll_float_fields = collectionManager.create_collection("coll_float_fields", fields, facet_fields, sort_fields_index);
    }

    // One JSON document per line.
    for(std::string json_line; std::getline(infile, json_line); ) {
        coll_float_fields->add(json_line);
    }

    infile.close();

    query_fields = {"title"};
    std::vector<std::string> facets;

    // Compares the id of every returned hit, position by position, with the expected ids.
    // Uses EXPECT (non-fatal) so that all mismatching positions get reported.
    const auto expect_hit_order = [](nlohmann::json & res, const std::vector<std::string> & expected_ids) {
        for(size_t pos = 0; pos < res["hits"].size(); pos++) {
            nlohmann::json hit = res["hits"].at(pos);
            std::string actual_id = hit["id"];
            std::string expected_id = expected_ids.at(pos);
            EXPECT_STREQ(expected_id.c_str(), actual_id.c_str());
        }
    };

    // both fields descending
    std::vector<sort_by> sort_fields_desc = { sort_by("score", "DESC"), sort_by("average", "DESC") };
    nlohmann::json results = coll_float_fields->search("Jeremy", query_fields, "", facets, sort_fields_desc, 0, 10, 1, FREQUENCY, false).get();
    ASSERT_EQ(7, results["hits"].size());
    expect_hit_order(results, {"2", "0", "3", "1", "5", "4", "6"});

    // both fields ascending
    std::vector<sort_by> sort_fields_asc = { sort_by("score", "ASC"), sort_by("average", "ASC") };
    results = coll_float_fields->search("Jeremy", query_fields, "", facets, sort_fields_asc, 0, 10, 1, FREQUENCY, false).get();
    ASSERT_EQ(7, results["hits"].size());
    expect_hit_order(results, {"6", "4", "5", "1", "3", "0", "2"});

    // second field by desc
    std::vector<sort_by> sort_fields_asc_desc = { sort_by("score", "ASC"), sort_by("average", "DESC") };
    results = coll_float_fields->search("Jeremy", query_fields, "", facets, sort_fields_asc_desc, 0, 10, 1, FREQUENCY, false).get();
    ASSERT_EQ(7, results["hits"].size());
    expect_hit_order(results, {"5", "4", "6", "1", "3", "0", "2"});

    collectionManager.drop_collection("coll_float_fields");
}
|
||||
|
||||
TEST_F(CollectionTest, FilterOnTextFields) {
|
||||
Collection *coll_array_fields;
|
||||
|
||||
@ -770,7 +838,7 @@ TEST_F(CollectionTest, FilterOnTextFields) {
|
||||
field("tags", field_types::STRING_ARRAY)};
|
||||
|
||||
std::vector<field> sort_fields_index = { field("age", "INT32") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields");
|
||||
if(coll_array_fields == nullptr) {
|
||||
@ -842,7 +910,7 @@ TEST_F(CollectionTest, HandleBadlyFormedFilterQuery) {
|
||||
field("tags", field_types::STRING_ARRAY)};
|
||||
|
||||
std::vector<field> sort_fields_index = { field("age", "INT32") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields");
|
||||
if(coll_array_fields == nullptr) {
|
||||
@ -898,7 +966,7 @@ TEST_F(CollectionTest, FacetCounts) {
|
||||
facet_fields = {field("tags", field_types::STRING_ARRAY), field("name", field_types::STRING)};
|
||||
|
||||
std::vector<field> sort_fields_index = { field("age", "DESC") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields");
|
||||
if(coll_array_fields == nullptr) {
|
||||
@ -991,7 +1059,7 @@ TEST_F(CollectionTest, SortingOrder) {
|
||||
|
||||
query_fields = {"title"};
|
||||
std::vector<std::string> facets;
|
||||
sort_fields = { sort_field("points", "ASC") };
|
||||
sort_fields = { sort_by("points", "ASC") };
|
||||
nlohmann::json results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(10, results["hits"].size());
|
||||
|
||||
@ -1005,7 +1073,7 @@ TEST_F(CollectionTest, SortingOrder) {
|
||||
}
|
||||
|
||||
// limiting results to just 5, "ASC" keyword must be case insensitive
|
||||
sort_fields = { sort_field("points", "asc") };
|
||||
sort_fields = { sort_by("points", "asc") };
|
||||
results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 5, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
@ -1020,7 +1088,7 @@ TEST_F(CollectionTest, SortingOrder) {
|
||||
|
||||
// desc
|
||||
|
||||
sort_fields = { sort_field("points", "dEsc") };
|
||||
sort_fields = { sort_by("points", "dEsc") };
|
||||
results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(10, results["hits"].size());
|
||||
|
||||
@ -1062,7 +1130,7 @@ TEST_F(CollectionTest, SearchingWithMissingFields) {
|
||||
field("tags", field_types::STRING_ARRAY)};
|
||||
facet_fields = {field("tags", field_types::STRING_ARRAY), field("name", field_types::STRING)};
|
||||
std::vector<field> sort_fields_index = { field("age", "DESC") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields");
|
||||
if(coll_array_fields == nullptr) {
|
||||
@ -1097,11 +1165,11 @@ TEST_F(CollectionTest, SearchingWithMissingFields) {
|
||||
ASSERT_STREQ("Could not find a facet field named `timestamps` in the schema.", res_op.error().c_str());
|
||||
|
||||
// when a rank field is not defined in the schema
|
||||
res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_field("timestamps", "ASC") }, 0, 10);
|
||||
res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_by("timestamps", "ASC") }, 0, 10);
|
||||
ASSERT_EQ(400, res_op.code());
|
||||
ASSERT_STREQ("Could not find a sort field named `timestamps` in the schema.", res_op.error().c_str());
|
||||
|
||||
res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_field("_rank", "ASC") }, 0, 10);
|
||||
res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_by("_rank", "ASC") }, 0, 10);
|
||||
ASSERT_EQ(400, res_op.code());
|
||||
ASSERT_STREQ("Could not find a sort field named `_rank` in the schema.", res_op.error().c_str());
|
||||
|
||||
@ -1116,7 +1184,7 @@ TEST_F(CollectionTest, IndexingWithBadData) {
|
||||
facet_fields = {field("tags", field_types::STRING_ARRAY)};
|
||||
|
||||
std::vector<field> sort_fields_index = { field("age", "INT32"), field("average", "INT32") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC"), sort_field("average", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
|
||||
|
||||
sample_collection = collectionManager.get_collection("sample_collection");
|
||||
if(sample_collection == nullptr) {
|
||||
@ -1193,7 +1261,7 @@ TEST_F(CollectionTest, EmptyIndexShouldNotCrash) {
|
||||
facet_fields = {field("tags", field_types::STRING_ARRAY)};
|
||||
|
||||
std::vector<field> sort_fields_index = { field("age", "INT32"), field("average", "INT32") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC"), sort_field("average", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
|
||||
|
||||
empty_coll = collectionManager.get_collection("empty_coll");
|
||||
if(empty_coll == nullptr) {
|
||||
@ -1212,7 +1280,7 @@ TEST_F(CollectionTest, IdFieldShouldBeAString) {
|
||||
facet_fields = {field("tags", field_types::STRING_ARRAY)};
|
||||
|
||||
std::vector<field> sort_fields_index = { field("age", "INT32"), field("average", "INT32") };
|
||||
std::vector<sort_field> sort_fields = { sort_field("age", "DESC"), sort_field("average", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
|
||||
|
||||
coll1 = collectionManager.get_collection("coll1");
|
||||
if(coll1 == nullptr) {
|
||||
@ -1241,7 +1309,7 @@ TEST_F(CollectionTest, DeletionOfADocument) {
|
||||
std::vector<field> search_fields = {field("title", field_types::STRING)};
|
||||
std::vector<std::string> query_fields = {"title"};
|
||||
std::vector<field> facet_fields = { };
|
||||
std::vector<sort_field> sort_fields = { sort_field("points", "DESC") };
|
||||
std::vector<sort_by> sort_fields = { sort_by("points", "DESC") };
|
||||
std::vector<field> sort_fields_index = { field("points", "INT32") };
|
||||
|
||||
Collection *collection_for_del;
|
||||
|
7
test/float_documents.jsonl
Normal file
7
test/float_documents.jsonl
Normal file
@ -0,0 +1,7 @@
|
||||
{"title": "Jeremy Howard", "score": 1.09, "average": 1.45}
|
||||
{"title": "Jeremy Howard", "score": -9.998, "average": -2.408 }
|
||||
{"title": "Jeremy Howard", "score": 7.812, "average": 0.001 }
|
||||
{"title": "Jeremy Howard", "score": 0.0, "average": 11.533 }
|
||||
{"title": "Jeremy Howard", "score": -9.999, "average": -11.38 }
|
||||
{"title": "Jeremy Howard", "score": -9.999, "average": 19.38 }
|
||||
{"title": "Jeremy Howard", "score": -9.999, "average": -21.38 }
|
@ -19,4 +19,10 @@ TEST(MatchScoreTest, ShouldPackTokenOffsets) {
|
||||
ASSERT_EQ(0, offset_diffs[1]);
|
||||
ASSERT_EQ(1, offset_diffs[2]);
|
||||
ASSERT_EQ(2, offset_diffs[3]);
|
||||
|
||||
uint16_t min_token_offset3[1] = {123};
|
||||
MatchScore::pack_token_offsets(min_token_offset3, 1, 0, offset_diffs);
|
||||
|
||||
ASSERT_EQ(1, offset_diffs[0]);
|
||||
ASSERT_EQ(0, offset_diffs[1]);
|
||||
}
|
@ -52,18 +52,18 @@ TEST(TopsterTest, StoreMaxFloatValuesWithoutRepetition) {
|
||||
float primary_attr;
|
||||
int64_t secondary_attr;
|
||||
} data[12] = {
|
||||
{0, 1, 11, 20.04, 30},
|
||||
{0, 2, 4, 20, 30},
|
||||
{2, 3, 7, 20, 30},
|
||||
{0, 4, 11, 20.05, 30},
|
||||
{0, 4, 11, 20.05, 30},
|
||||
{1, 5, 9, 24.50, 34},
|
||||
{0, 6, 6, 20, 30},
|
||||
{2, 7, 6, 22, 30},
|
||||
{1, 8, 9, 24.50, 30},
|
||||
{1, 8, 9, 24.50, 30},
|
||||
{0, 9, 8, 24.50, 30},
|
||||
{3, 10, 5, 20, 30},
|
||||
{0, 1, 11, 1.09, 30},
|
||||
{0, 2, 11, -20, 30},
|
||||
{2, 3, 11, -20, 30},
|
||||
{0, 4, 11, 7.812, 30},
|
||||
{0, 4, 11, 7.812, 30},
|
||||
{1, 5, 11, 0.0, 34},
|
||||
{0, 6, 11, -22, 30},
|
||||
{2, 7, 11, -22, 30},
|
||||
{1, 8, 11, -9.998, 30},
|
||||
{1, 8, 11, -9.998, 30},
|
||||
{0, 9, 11, -9.999, 30},
|
||||
{3, 10, 11, -20, 30},
|
||||
};
|
||||
|
||||
for(int i = 0; i < 12; i++) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user