Mirror of https://github.com/typesense/typesense.git

Commit 052d415af6
Merge branch 'v0.26-filter' into v0.26-filter
@@ -207,7 +207,8 @@ private:
     Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
                                                       std::vector<sort_by>& sort_fields_std,
-                                                      bool is_wildcard_query, bool is_group_by_query = false) const;
+                                                      bool is_wildcard_query, const bool is_vector_query,
+                                                      bool is_group_by_query = false) const;

     Option<bool> persist_collection_meta();
@@ -462,7 +463,8 @@ public:
                                     const text_match_type_t match_type = max_score,
                                     const size_t facet_sample_percent = 100,
                                     const size_t facet_sample_threshold = 0,
-                                    const size_t page_offset = UINT32_MAX) const;
+                                    const size_t page_offset = UINT32_MAX,
+                                    const size_t vector_query_hits = 250) const;

     Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;
@@ -550,6 +550,8 @@ namespace sort_field_const {
     static const std::string precision = "precision";

     static const std::string missing_values = "missing_values";
+
+    static const std::string vector_distance = "_vector_distance";
 }

 struct sort_by {
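Note: `_vector_distance` is a pseudo sort field resolved through sort_field_const::vector_distance rather than through the collection schema. A minimal usage sketch, mirroring the tests added later in this commit (assumes the typesense headers above):

    std::vector<sort_by> sort_fields = {
        sort_by("_vector_distance", "asc"),   // ascending = closest vectors first
    };
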
@@ -347,6 +347,7 @@ private:
     static spp::sparse_hash_map<uint32_t, int64_t> eval_sentinel_value;
     static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
     static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
+    static spp::sparse_hash_map<uint32_t, int64_t> vector_distance_sentinel_value;

     // Internal utility functions
@@ -917,7 +918,7 @@ public:
                             size_t filter_index,
                             int64_t max_field_match_score,
                             int64_t* scores,
-                            int64_t& match_score_index) const;
+                            int64_t& match_score_index, float vector_distance = 0) const;

     void process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
                              const std::vector<uint32_t>& excluded_ids, const std::vector<std::string>& group_by_fields,
@@ -742,6 +742,7 @@ void Collection::curate_results(string& actual_query, const string& filter_query
 Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
                                                               std::vector<sort_by>& sort_fields_std,
                                                               const bool is_wildcard_query,
+                                                              const bool is_vector_query,
                                                               const bool is_group_by_query) const {

     size_t num_sort_expressions = 0;
@@ -916,7 +917,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
         }

         if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
-            sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found) {
+            sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) {

             const auto field_it = search_schema.find(sort_field_std.name);
             if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
@@ -930,6 +931,11 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
             std::string error = "group_by parameters should not be empty when using sort_by group_found";
             return Option<bool>(404, error);
         }

+        if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) {
+            std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.";
+            return Option<bool>(404, error);
+        }
+
         StringUtils::toupper(sort_field_std.order);
@@ -952,6 +958,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
         sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
     }

+    if(is_vector_query) {
+        sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc);
+    }
+
     if(!default_sorting_field.empty()) {
         sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
     } else {
@@ -960,9 +970,15 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
     }

     bool found_match_score = false;
+    bool found_vector_distance = false;
     for(const auto & sort_field : sort_fields_std) {
         if(sort_field.name == sort_field_const::text_match) {
             found_match_score = true;
         }
+        if(sort_field.name == sort_field_const::vector_distance) {
+            found_vector_distance = true;
+        }
+        if(found_match_score && found_vector_distance) {
+            break;
+        }
     }
@@ -971,6 +987,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
         sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
     }

+    if(!found_vector_distance && is_vector_query && sort_fields.size() < 3) {
+        sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc);
+    }
+
     if(sort_fields_std.size() > 3) {
         std::string message = "Only upto 3 sort_by fields can be specified.";
         return Option<bool>(422, message);
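Note: a summary sketch of the net effect of the two insertions above (not code from the commit; `default_sorting_field` is whatever the schema declares):

    // For a vector query with no explicit sort_by, the standardized order becomes:
    //   1. _vector_distance : asc            (injected because is_vector_query)
    //   2. <default_sorting_field> : desc    (if the schema declares one)
    // For queries that already carry explicit sort fields, _vector_distance is
    // appended only while fewer than three slots are in use, since at most
    // 3 sort_by fields are allowed overall.
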
@@ -1087,7 +1107,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                                           const text_match_type_t match_type,
                                           const size_t facet_sample_percent,
                                           const size_t facet_sample_threshold,
-                                          const size_t page_offset) const {
+                                          const size_t page_offset,
+                                          const size_t vector_query_hits) const {

     std::shared_lock lock(mutex);
@@ -1228,6 +1249,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
             vector_query._reset();
             vector_query.values = embedding;
             vector_query.field_name = field_name;
+            vector_query.k = vector_query_hits;
             continue;
         }
@@ -1457,10 +1479,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,

     bool is_wildcard_query = (query == "*");
     bool is_group_by_query = group_by_fields.size() > 0;
+    bool is_vector_query = !vector_query.field_name.empty();

     if(curated_sort_by.empty()) {
         auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
-                                                                       sort_fields_std, is_wildcard_query, is_group_by_query);
+                                                                       sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
         if(!sort_validation_op.ok()) {
             return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
         }
@@ -1472,7 +1495,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
         }

         auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
-                                                                       sort_fields_std, is_wildcard_query, is_group_by_query);
+                                                                       sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
         if(!sort_validation_op.ok()) {
             return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
         }
@@ -1923,7 +1946,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
         }

         if(!vector_query.field_name.empty() && query == "*") {
-            wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]);
+            wrapper_doc["vector_distance"] = field_order_kv->vector_distance;
         }

         hits_array.push_back(wrapper_doc);
@@ -669,6 +669,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
     const char *MAX_FACET_VALUES = "max_facet_values";

     const char *VECTOR_QUERY = "vector_query";
+    const char *VECTOR_QUERY_HITS = "vector_query_hits";

     const char *GROUP_BY = "group_by";
     const char *GROUP_LIMIT = "group_limit";
@@ -821,6 +822,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
     size_t max_extra_suffix = INT16_MAX;
     bool enable_highlight_v1 = true;
     text_match_type_t match_type = max_score;
+    size_t vector_query_hits = 250;

     size_t facet_sample_percent = 100;
     size_t facet_sample_threshold = 0;
@@ -847,6 +849,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
         {FILTER_CURATED_HITS, &filter_curated_hits_option},
         {FACET_SAMPLE_PERCENT, &facet_sample_percent},
         {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
+        {VECTOR_QUERY_HITS, &vector_query_hits},
     };

     std::unordered_map<std::string, std::string*> str_values = {
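Note: a hedged end-to-end sketch of how the new knob travels from the HTTP layer to the ANN search; the endpoint shape is the standard search API, and the parameter and variable names come from the hunks above:

    // GET /collections/coll1/documents/search?q=*&vector_query=vec:([...])&vector_query_hits=500
    //
    // do_search() routes the raw parameter into the size_t declared above:
    size_t vector_query_hits = 250;          // default when the parameter is absent
    // ...parameter parsing then overwrites vector_query_hits from req_params...
    // Collection::search() later forwards it to the vector query:
    //   vector_query.k = vector_query_hits; // caps how many neighbors HNSW returns
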
@@ -547,7 +547,9 @@ bool HttpServer::is_write_request(const std::string& root_resource, const std::s
         return false;
     }

-    bool write_free_request = (root_resource == "multi_search" || root_resource == "operations");
+    bool write_free_request = (root_resource == "multi_search" || root_resource == "proxy" ||
+                               root_resource == "operations");

     if(!write_free_request &&
        (http_method == "POST" || http_method == "PUT" ||
         http_method == "DELETE" || http_method == "PATCH")) {
@@ -43,6 +43,7 @@ spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
 spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
 spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
 spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
+spp::sparse_hash_map<uint32_t, int64_t> Index::vector_distance_sentinel_value;

 struct token_posting_t {
     uint32_t token_id;
@@ -2337,12 +2338,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 }

                 int64_t scores[3] = {0};
-                scores[0] = -float_to_int64_t(vec_dist_score);
                 int64_t match_score_index = -1;

                 //LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first;
+                compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0, 0, scores, match_score_index, vec_dist_score);

                 KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
+                kv.vector_distance = vec_dist_score;
                 int ret = topster->add(&kv);

                 if(group_limit != 0 && ret < 2) {
@@ -2563,7 +2564,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
             VectorFilterFunctor filterFunctor(filter_result_iterator);
             auto& field_vector_index = vector_index.at(vector_query.field_name);
             std::vector<std::pair<float, size_t>> dist_labels;
-            auto k = std::max<size_t>(vector_query.k, fetch_size);
+            // use k as 100 by default for ensuring results stability in pagination
+            size_t default_k = 100;
+            auto k = std::max<size_t>(vector_query.k, default_k);

             if(field_vector_index->distance_type == cosine) {
                 std::vector<float> normalized_q(vector_query.values.size());
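Note: a small sketch of why a fixed floor on k stabilizes pagination (numbers illustrative). Previously k tracked fetch_size, which grows as the caller paginates deeper, so each page could re-run the ANN search with a different k and reshuffle the candidate set:

    #include <algorithm>
    #include <cstddef>

    void k_floor_sketch() {
        size_t user_k = 0;                                 // user did not set vector_query.k
        size_t old_page1 = std::max<size_t>(user_k, 10);   // old behaviour: k follows fetch_size (page 1)
        size_t old_page5 = std::max<size_t>(user_k, 50);   // deeper page -> larger k -> different neighbor set
        size_t new_any   = std::max<size_t>(user_k, 100);  // new behaviour: default_k, identical on every page
        (void) old_page1; (void) old_page5; (void) new_any;
    }
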
@@ -2599,33 +2602,63 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     continue;
                 }
                 // (1 / rank_of_document) * WEIGHT)

                 result->text_match_score = result->scores[result->match_score_index];
                 LOG(INFO) << "SEQ_ID: " << result->key << ", score: " << result->text_match_score;
                 result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT);
             }

-            for(int i = 0; i < vec_results.size(); i++) {
-                auto& vec_result = vec_results[i];
-                auto doc_id = vec_result.first;
+            std::vector<uint32_t> vec_search_ids;  // list of IDs found only in vector search
+
+            for(size_t res_index = 0; res_index < vec_results.size(); res_index++) {
+                auto& vec_result = vec_results[res_index];
+                auto doc_id = vec_result.first;
                 auto result_it = topster->kv_map.find(doc_id);

-                if(result_it != topster->kv_map.end() && result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) {
+                if(result_it != topster->kv_map.end()) {
+                    if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) {
+                        continue;
+                    }
+
+                    // result overlaps with keyword search: we have to combine the scores

                     auto result = result_it->second;
                     // old_score + (1 / rank_of_document) * WEIGHT)
-                    result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT));
+                    result->vector_distance = vec_result.second;
+                    result->scores[result->match_score_index] = float_to_int64_t(
+                            (int64_t_to_float(result->scores[result->match_score_index])) +
+                            ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
+
+                    for(size_t i = 0; i < 3; i++) {
+                        if(field_values[i] == &vector_distance_sentinel_value) {
+                            result->scores[i] = float_to_int64_t(vec_result.second);
+                        }
+
+                        if(sort_order[i] == -1) {
+                            result->scores[i] = -result->scores[i];
+                        }
+                    }
                 } else {
-                    int64_t scores[3] = {0};
                     // Result has been found only in vector search: we have to add it to both KV and result_ids
                     // (1 / rank_of_document) * WEIGHT)
-                    scores[0] = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT);
-                    int64_t match_score_index = 0;
+                    int64_t scores[3] = {0};
+                    int64_t match_score = float_to_int64_t((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT);
+                    int64_t match_score_index = -1;
+                    compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
                     KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
+                    kv.vector_distance = vec_result.second;
                     topster->add(&kv);
                     ++all_result_ids_len;
+                    vec_search_ids.push_back(doc_id);
                 }
             }

+            if(!vec_search_ids.empty()) {
+                uint32_t* new_all_result_ids = nullptr;
+                all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &vec_search_ids[0],
+                                                           vec_search_ids.size(), &new_all_result_ids);
+                delete[] all_result_ids;
+                all_result_ids = new_all_result_ids;
+            }
         }
     }
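Note: the merge above is a weighted reciprocal-rank fusion: each result list contributes (1 / rank) * weight, and a document found by both keyword and vector search receives the sum. A standalone sketch of the scoring rule, with text_match_weight and vector_search_weight standing in for the TEXT_MATCH_WEIGHT / VECTOR_SEARCH_WEIGHT constants defined elsewhere in the index code:

    #include <cstddef>

    // fused_score(0, 0, ...) is the best possible score; ranks are 0-based,
    // hence the +1, matching (1.0 / (i + 1)) in the diff above.
    float fused_score(size_t keyword_rank, size_t vector_rank,
                      float text_match_weight, float vector_search_weight) {
        return (1.0f / (keyword_rank + 1)) * text_match_weight +
               (1.0f / (vector_rank + 1)) * vector_search_weight;
    }
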
@@ -3589,7 +3622,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
                                 std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
                                 const std::vector<size_t>& geopoint_indices,
                                 uint32_t seq_id, size_t filter_index, int64_t max_field_match_score,
-                                int64_t* scores, int64_t& match_score_index) const {
+                                int64_t* scores, int64_t& match_score_index, float vector_distance) const {

     int64_t geopoint_distances[3];
@@ -3684,6 +3717,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
             }

             scores[0] = int64_t(found);
+        } else if(field_values[0] == &vector_distance_sentinel_value) {
+            scores[0] = float_to_int64_t(vector_distance);
         } else {
             auto it = field_values[0]->find(seq_id);
             scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
@@ -3740,6 +3775,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
             }

             scores[1] = int64_t(found);
+        } else if(field_values[1] == &vector_distance_sentinel_value) {
+            scores[1] = float_to_int64_t(vector_distance);
         } else {
             auto it = field_values[1]->find(seq_id);
             scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
@@ -3792,6 +3829,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
             }

             scores[2] = int64_t(found);
+        } else if(field_values[2] == &vector_distance_sentinel_value) {
+            scores[2] = float_to_int64_t(vector_distance);
         } else {
             auto it = field_values[2]->find(seq_id);
             scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
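Note: all three sort slots compare int64_t keys, so the float distance must pass through float_to_int64_t, an order-preserving float-to-integer mapping. One standard construction is shown below purely as an illustration of the idea; it is not necessarily the exact helper defined in the Index class:

    #include <cstdint>
    #include <cstring>

    // Maps a float to an integer whose signed ordering matches the float ordering:
    // non-negative floats already compare correctly bit-wise; negative floats get
    // their magnitude bits flipped so that "more negative" maps to "smaller int".
    int64_t order_preserving_float_key(float f) {
        int32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));  // well-defined type pun
        if(bits < 0) {
            bits ^= 0x7FFFFFFF;
        }
        return static_cast<int64_t>(bits);
    }
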
@@ -4490,15 +4529,14 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
             field_values[i] = &seq_id_sentinel_value;
         } else if (sort_fields_std[i].name == sort_field_const::eval) {
             field_values[i] = &eval_sentinel_value;

             auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root);
             auto filter_init_op = filter_result_iterator.init_status();
             if (!filter_init_op.ok()) {
                 return;
             }

             sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids);
+        } else if(sort_fields_std[i].name == sort_field_const::vector_distance) {
+            field_values[i] = &vector_distance_sentinel_value;
         } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
             if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
                 geopoint_indices.push_back(i);
@@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {

 TextEmbedder::TextEmbedder(const nlohmann::json& model_config) {
     auto model_name = model_config["model_name"].get<std::string>();
-    LOG(INFO) << "Loading model from remote: " << model_name;
+    LOG(INFO) << "Initializing remote embedding model: " << model_name;
     auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);

     if(model_namespace == "openai") {
@@ -2246,3 +2246,145 @@ TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) {
         ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
     }
 }
+
+TEST_F(CollectionSortingTest, AscendingVectorDistance) {
+    std::string coll_schema = R"(
+        {
+            "name": "coll1",
+            "fields": [
+                {"name": "title", "type": "string" },
+                {"name": "points", "type": "float[]", "num_dim": 2}
+            ]
+        }
+    )";
+
+    nlohmann::json schema = nlohmann::json::parse(coll_schema);
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    std::vector<std::vector<float>> points = {
+        {3.0, 4.0},
+        {9.0, 21.0},
+        {8.0, 15.0},
+        {1.0, 1.0},
+        {5.0, 7.0}
+    };
+
+    for(size_t i = 0; i < points.size(); i++) {
+        nlohmann::json doc;
+        doc["title"] = "Title " + std::to_string(i);
+        doc["points"] = points[i];
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    std::vector<sort_by> sort_fields = {
+        sort_by("_vector_distance", "asc"),
+    };
+
+    auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                 spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                 "", 10, {}, {}, {}, 0,
+                                 "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
+                                 4, {off}, 32767, 32767, 2,
+                                 false, true, "points:([8.0, 15.0])").get();
+
+    ASSERT_EQ(5, results["hits"].size());
+
+    std::vector<std::string> expected_ids = {"2", "1", "4", "0", "3"};
+
+    for(size_t i = 0; i < expected_ids.size(); i++) {
+        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
+    }
+}
+
+TEST_F(CollectionSortingTest, DescendingVectorDistance) {
+    std::string coll_schema = R"(
+        {
+            "name": "coll1",
+            "fields": [
+                {"name": "title", "type": "string" },
+                {"name": "points", "type": "float[]", "num_dim": 2}
+            ]
+        }
+    )";
+
+    nlohmann::json schema = nlohmann::json::parse(coll_schema);
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    std::vector<std::vector<float>> points = {
+        {3.0, 4.0},
+        {9.0, 21.0},
+        {8.0, 15.0},
+        {1.0, 1.0},
+        {5.0, 7.0}
+    };
+
+    for(size_t i = 0; i < points.size(); i++) {
+        nlohmann::json doc;
+        doc["title"] = "Title " + std::to_string(i);
+        doc["points"] = points[i];
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    std::vector<sort_by> sort_fields = {
+        sort_by("_vector_distance", "DESC"),
+    };
+
+    auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                 spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                 "", 10, {}, {}, {}, 0,
+                                 "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
+                                 4, {off}, 32767, 32767, 2,
+                                 false, true, "points:([8.0, 15.0])").get();
+
+    ASSERT_EQ(5, results["hits"].size());
+
+    std::vector<std::string> expected_ids = {"3", "0", "4", "1", "2"};
+
+    for(size_t i = 0; i < expected_ids.size(); i++) {
+        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
+    }
+}
+
+TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) {
+    std::string coll_schema = R"(
+        {
+            "name": "coll1",
+            "fields": [
+                {"name": "title", "type": "string" },
+                {"name": "points", "type": "float[]", "num_dim": 2}
+            ]
+        }
+    )";
+
+    nlohmann::json schema = nlohmann::json::parse(coll_schema);
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    std::vector<std::vector<float>> points = {
+        {1.0, 1.0},
+        {2.0, 2.0},
+        {3.0, 3.0},
+        {4.0, 4.0},
+        {5.0, 5.0},
+    };
+
+    for(size_t i = 0; i < points.size(); i++) {
+        nlohmann::json doc;
+        doc["title"] = "Title " + std::to_string(i);
+        doc["points"] = points[i];
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    std::vector<sort_by> sort_fields = {
+        sort_by("_vector_distance", "desc"),
+    };
+
+    auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
+
+    ASSERT_FALSE(results.ok());
+    ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error());
+}
@@ -711,12 +711,42 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
                                     4, {off}, 32767, 32767, 2,
                                     false, true, "vec:(" + dummy_vec_string +")");
     ASSERT_EQ(true, results_op.ok());

     ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
     ASSERT_EQ(1, results_op.get()["hits"].size());
 }

+TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "fields": [
+            {"name": "name", "type": "string", "facet": true},
+            {"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
+        ]
+    })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    nlohmann::json doc;
+    doc["name"] = "john doe";
+    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+
+    auto results_op = coll1->search("zzz", {"name", "vec"}, "", {"name"}, {}, {0}, 20, 1, FREQUENCY, {true},
+                                    Index::DROP_TOKENS_THRESHOLD,
+                                    spp::sparse_hash_set<std::string>(),
+                                    spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                    "", 10, {}, {}, {}, 0,
+                                    "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                    fallback,
+                                    4, {off}, 32767, 32767, 2);
+    ASSERT_EQ(true, results_op.ok());
+    ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
+    ASSERT_EQ(1, results_op.get()["hits"].size());
+    ASSERT_EQ(1, results_op.get()["facet_counts"].size());
+    ASSERT_EQ(4, results_op.get()["facet_counts"][0].size());
+    ASSERT_EQ("name", results_op.get()["facet_counts"][0]["field_name"]);
+}
+
 TEST_F(CollectionVectorTest, DistanceThresholdTest) {
     nlohmann::json schema = R"({
         "name": "test",