Merge branch 'v0.26-filter' into v0.26-filter

This commit is contained in:
Kishore Nallan 2023-06-28 21:12:30 +05:30 committed by GitHub
commit 052d415af6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 272 additions and 29 deletions

View File

@ -207,7 +207,8 @@ private:
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
bool is_wildcard_query, bool is_group_by_query = false) const;
bool is_wildcard_query,const bool is_vector_query,
bool is_group_by_query = false) const;
Option<bool> persist_collection_meta();
@ -462,7 +463,8 @@ public:
const text_match_type_t match_type = max_score,
const size_t facet_sample_percent = 100,
const size_t facet_sample_threshold = 0,
const size_t page_offset = UINT32_MAX) const;
const size_t page_offset = UINT32_MAX,
const size_t vector_query_hits = 250) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

View File

@ -550,6 +550,8 @@ namespace sort_field_const {
static const std::string precision = "precision";
static const std::string missing_values = "missing_values";
static const std::string vector_distance = "_vector_distance";
}
struct sort_by {

View File

@ -347,6 +347,7 @@ private:
static spp::sparse_hash_map<uint32_t, int64_t> eval_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> vector_distance_sentinel_value;
// Internal utility functions
@ -917,7 +918,7 @@ public:
size_t filter_index,
int64_t max_field_match_score,
int64_t* scores,
int64_t& match_score_index) const;
int64_t& match_score_index, float vector_distance = 0) const;
void process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<std::string>& group_by_fields,

View File

@ -742,6 +742,7 @@ void Collection::curate_results(string& actual_query, const string& filter_query
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query,
const bool is_vector_query,
const bool is_group_by_query) const {
size_t num_sort_expressions = 0;
@ -916,7 +917,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
}
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found) {
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
@ -930,6 +931,11 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
std::string error = "group_by parameters should not be empty when using sort_by group_found";
return Option<bool>(404, error);
}
if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) {
std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.";
return Option<bool>(404, error);
}
StringUtils::toupper(sort_field_std.order);
@ -952,6 +958,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
}
if(is_vector_query) {
sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc);
}
if(!default_sorting_field.empty()) {
sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
} else {
@ -960,9 +970,15 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
}
bool found_match_score = false;
bool found_vector_distance = false;
for(const auto & sort_field : sort_fields_std) {
if(sort_field.name == sort_field_const::text_match) {
found_match_score = true;
}
if(sort_field.name == sort_field_const::vector_distance) {
found_vector_distance = true;
}
if(found_match_score && found_vector_distance) {
break;
}
}
@ -971,6 +987,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
}
if(!found_vector_distance && is_vector_query && sort_fields.size() < 3) {
sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc);
}
if(sort_fields_std.size() > 3) {
std::string message = "Only upto 3 sort_by fields can be specified.";
return Option<bool>(422, message);
@ -1087,7 +1107,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const text_match_type_t match_type,
const size_t facet_sample_percent,
const size_t facet_sample_threshold,
const size_t page_offset) const {
const size_t page_offset,
const size_t vector_query_hits) const {
std::shared_lock lock(mutex);
@ -1228,6 +1249,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
vector_query._reset();
vector_query.values = embedding;
vector_query.field_name = field_name;
vector_query.k = vector_query_hits;
continue;
}
@ -1457,10 +1479,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
bool is_wildcard_query = (query == "*");
bool is_group_by_query = group_by_fields.size() > 0;
bool is_vector_query = !vector_query.field_name.empty();
if(curated_sort_by.empty()) {
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
sort_fields_std, is_wildcard_query, is_group_by_query);
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
@ -1472,7 +1495,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
sort_fields_std, is_wildcard_query, is_group_by_query);
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
@ -1923,7 +1946,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
if(!vector_query.field_name.empty() && query == "*") {
wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]);
wrapper_doc["vector_distance"] = field_order_kv->vector_distance;
}
hits_array.push_back(wrapper_doc);

View File

@ -669,6 +669,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *MAX_FACET_VALUES = "max_facet_values";
const char *VECTOR_QUERY = "vector_query";
const char *VECTOR_QUERY_HITS = "vector_query_hits";
const char *GROUP_BY = "group_by";
const char *GROUP_LIMIT = "group_limit";
@ -821,6 +822,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
size_t max_extra_suffix = INT16_MAX;
bool enable_highlight_v1 = true;
text_match_type_t match_type = max_score;
size_t vector_query_hits = 250;
size_t facet_sample_percent = 100;
size_t facet_sample_threshold = 0;
@ -847,6 +849,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{FILTER_CURATED_HITS, &filter_curated_hits_option},
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
{VECTOR_QUERY_HITS, &vector_query_hits},
};
std::unordered_map<std::string, std::string*> str_values = {

View File

@ -547,7 +547,9 @@ bool HttpServer::is_write_request(const std::string& root_resource, const std::s
return false;
}
bool write_free_request = (root_resource == "multi_search" || root_resource == "operations");
bool write_free_request = (root_resource == "multi_search" || root_resource == "proxy" ||
root_resource == "operations");
if(!write_free_request &&
(http_method == "POST" || http_method == "PUT" ||
http_method == "DELETE" || http_method == "PATCH")) {

View File

@ -43,6 +43,7 @@ spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::vector_distance_sentinel_value;
struct token_posting_t {
uint32_t token_id;
@ -2337,12 +2338,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
int64_t scores[3] = {0};
scores[0] = -float_to_int64_t(vec_dist_score);
int64_t match_score_index = -1;
//LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first;
compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0, 0, scores, match_score_index, vec_dist_score);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
kv.vector_distance = vec_dist_score;
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
@ -2563,7 +2564,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
VectorFilterFunctor filterFunctor(filter_result_iterator);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
auto k = std::max<size_t>(vector_query.k, fetch_size);
// use k as 100 by default for ensuring results stability in pagination
size_t default_k = 100;
auto k = std::max<size_t>(vector_query.k, default_k);
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
@ -2599,33 +2602,63 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
continue;
}
// (1 / rank_of_document) * WEIGHT)
result->text_match_score = result->scores[result->match_score_index];
LOG(INFO) << "SEQ_ID: " << result->key << ", score: " << result->text_match_score;
result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT);
}
for(int i = 0; i < vec_results.size(); i++) {
auto& vec_result = vec_results[i];
auto doc_id = vec_result.first;
std::vector<uint32_t> vec_search_ids; // list of IDs found only in vector search
for(size_t res_index = 0; res_index < vec_results.size(); res_index++) {
auto& vec_result = vec_results[res_index];
auto doc_id = vec_result.first;
auto result_it = topster->kv_map.find(doc_id);
if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) {
if(result_it != topster->kv_map.end()) {
if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) {
continue;
}
// result overlaps with keyword search: we have to combine the scores
auto result = result_it->second;
// old_score + (1 / rank_of_document) * WEIGHT)
result->vector_distance = vec_result.second;
result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT));
result->scores[result->match_score_index] = float_to_int64_t(
(int64_t_to_float(result->scores[result->match_score_index])) +
((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
for(size_t i = 0;i < 3; i++) {
if(field_values[i] == &vector_distance_sentinel_value) {
result->scores[i] = float_to_int64_t(vec_result.second);
}
if(sort_order[i] == -1) {
result->scores[i] = -result->scores[i];
}
}
} else {
int64_t scores[3] = {0};
// Result has been found only in vector search: we have to add it to both KV and result_ids
// (1 / rank_of_document) * WEIGHT)
scores[0] = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT);
int64_t match_score_index = 0;
int64_t scores[3] = {0};
int64_t match_score = float_to_int64_t((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT);
int64_t match_score_index = -1;
compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
kv.vector_distance = vec_result.second;
topster->add(&kv);
++all_result_ids_len;
vec_search_ids.push_back(doc_id);
}
}
if(!vec_search_ids.empty()) {
uint32_t* new_all_result_ids = nullptr;
all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &vec_search_ids[0],
vec_search_ids.size(), &new_all_result_ids);
delete[] all_result_ids;
all_result_ids = new_all_result_ids;
}
}
}
@ -3589,7 +3622,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
uint32_t seq_id, size_t filter_index, int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index) const {
int64_t* scores, int64_t& match_score_index, float vector_distance) const {
int64_t geopoint_distances[3];
@ -3684,6 +3717,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
scores[0] = int64_t(found);
} else if(field_values[0] == &vector_distance_sentinel_value) {
scores[0] = float_to_int64_t(vector_distance);
} else {
auto it = field_values[0]->find(seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
@ -3740,6 +3775,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
scores[1] = int64_t(found);
} else if(field_values[1] == &vector_distance_sentinel_value) {
scores[1] = float_to_int64_t(vector_distance);
} else {
auto it = field_values[1]->find(seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
@ -3792,6 +3829,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
scores[2] = int64_t(found);
} else if(field_values[2] == &vector_distance_sentinel_value) {
scores[2] = float_to_int64_t(vector_distance);
} else {
auto it = field_values[2]->find(seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
@ -4490,15 +4529,14 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
field_values[i] = &seq_id_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
return;
}
sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids);
} else if(sort_fields_std[i].name == sort_field_const::vector_distance) {
field_values[i] = &vector_distance_sentinel_value;
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
geopoint_indices.push_back(i);

View File

@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
TextEmbedder::TextEmbedder(const nlohmann::json& model_config) {
auto model_name = model_config["model_name"].get<std::string>();
LOG(INFO) << "Loading model from remote: " << model_name;
LOG(INFO) << "Initializing remote embedding model: " << model_name;
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
if(model_namespace == "openai") {

View File

@ -2246,3 +2246,145 @@ TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
}
TEST_F(CollectionSortingTest, AscendingVectorDistance) {
    // Ascending sort on `_vector_distance` for a wildcard vector query should
    // return hits ordered from the nearest vector to the farthest one.
    std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "points", "type": "float[]", "num_dim": 2}
]
}
)";
    nlohmann::json schema = nlohmann::json::parse(coll_schema);
    Collection* coll = collectionManager.create_collection(schema).get();

    // 2-dim vectors for documents "0" .. "4"; the query vector below is [8.0, 15.0].
    std::vector<std::vector<float>> embeddings = {
        {3.0, 4.0},
        {9.0, 21.0},
        {8.0, 15.0},
        {1.0, 1.0},
        {5.0, 7.0}
    };

    size_t doc_index = 0;
    for(const auto& embedding : embeddings) {
        nlohmann::json doc;
        doc["title"] = "Title " + std::to_string(doc_index++);
        doc["points"] = embedding;
        ASSERT_TRUE(coll->add(doc.dump()).ok());
    }

    std::vector<sort_by> sort_fields = {
        sort_by("_vector_distance", "asc"),
    };

    auto res = coll->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                            spp::sparse_hash_set<std::string>(),
                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                            "", 10, {}, {}, {}, 0,
                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                            4, {off}, 32767, 32767, 2,
                            false, true, "points:([8.0, 15.0])").get();

    // Doc "2" holds the query vector itself, so it must come first; the rest
    // follow in order of increasing distance from the query vector.
    const std::vector<std::string> expected_ids = {"2", "1", "4", "0", "3"};
    ASSERT_EQ(expected_ids.size(), res["hits"].size());

    size_t hit_index = 0;
    for(const auto& expected_id : expected_ids) {
        ASSERT_EQ(expected_id, res["hits"][hit_index]["document"]["id"].get<std::string>());
        hit_index++;
    }
}
// Descending `_vector_distance` sort on a wildcard vector query should return
// hits ordered from the farthest vector to the nearest one (the reverse of the
// ascending case).
TEST_F(CollectionSortingTest, DescendingVectorDistance) {
// 2-dim float[] field serves as the vector field for this collection.
std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "points", "type": "float[]", "num_dim": 2}
]
}
)";
nlohmann::json schema = nlohmann::json::parse(coll_schema);
Collection* coll1 = collectionManager.create_collection(schema).get();
// Vectors for documents "0" .. "4"; the query vector below is [8.0, 15.0],
// which is identical to doc "2"'s vector.
std::vector<std::vector<float>> points = {
{3.0, 4.0},
{9.0, 21.0},
{8.0, 15.0},
{1.0, 1.0},
{5.0, 7.0}
};
for(size_t i = 0; i < points.size(); i++) {
nlohmann::json doc;
doc["title"] = "Title " + std::to_string(i);
doc["points"] = points[i];
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
// Mixed-case "DESC" is deliberate: sort-order strings are upper-cased during
// validation, so it must be accepted.
std::vector<sort_by> sort_fields = {
sort_by("_vector_distance", "DESC"),
};
// Wildcard query plus vector_query "points:([8.0, 15.0])" makes this a pure
// vector search.
auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
4, {off}, 32767, 32767, 2,
false, true, "points:([8.0, 15.0])").get();
ASSERT_EQ(5, results["hits"].size());
// Exact reverse of the ascending expectation: farthest document first, the
// query vector's own document ("2") last.
std::vector<std::string> expected_ids = {"3", "0", "4", "1", "2"};
for(size_t i = 0; i < expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
}
TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) {
    // Requesting a `_vector_distance` sort on a plain keyword search (no
    // vector_query) must be rejected by sort-field validation.
    std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "points", "type": "float[]", "num_dim": 2}
]
}
)";
    nlohmann::json schema = nlohmann::json::parse(coll_schema);
    Collection* coll = collectionManager.create_collection(schema).get();

    std::vector<std::vector<float>> embeddings = {
        {1.0, 1.0},
        {2.0, 2.0},
        {3.0, 3.0},
        {4.0, 4.0},
        {5.0, 5.0},
    };

    size_t doc_index = 0;
    for(const auto& embedding : embeddings) {
        nlohmann::json doc;
        doc["title"] = "Title " + std::to_string(doc_index++);
        doc["points"] = embedding;
        ASSERT_TRUE(coll->add(doc.dump()).ok());
    }

    std::vector<sort_by> sort_fields = {
        sort_by("_vector_distance", "desc"),
    };

    // Keyword-only search: the sort validation should fail with a clear error.
    auto search_op = coll->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
    ASSERT_FALSE(search_op.ok());
    ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", search_op.error());
}

View File

@ -711,12 +711,42 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
4, {off}, 32767, 32767, 2,
false, true, "vec:(" + dummy_vec_string +")");
ASSERT_EQ(true, results_op.ok());
ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
ASSERT_EQ(1, results_op.get()["hits"].size());
}
// Hybrid search where the keyword part matches nothing: a document found only
// on the vector (embedding) side must still be returned, and facet counts must
// be computed for it.
TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {
// `vec` is auto-embedded from `name` via the ts/e5-small model config.
nlohmann::json schema = R"({
"name": "coll1",
"fields": [
{"name": "name", "type": "string", "facet": true},
{"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
Collection* coll1 = collectionManager.create_collection(schema).get();
nlohmann::json doc;
doc["name"] = "john doe";
ASSERT_TRUE(coll1->add(doc.dump()).ok());
// Query "zzz" has no keyword overlap with "john doe"; searching both `name`
// and `vec` makes this a hybrid search, so the hit can only come from the
// vector side. Faceting on `name` is requested as well.
auto results_op = coll1->search("zzz", {"name", "vec"}, "", {"name"}, {}, {0}, 20, 1, FREQUENCY, {true},
Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2);
ASSERT_EQ(true, results_op.ok());
// The single vector-only match is counted, returned, and faceted.
ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
ASSERT_EQ(1, results_op.get()["hits"].size());
ASSERT_EQ(1, results_op.get()["facet_counts"].size());
ASSERT_EQ(4, results_op.get()["facet_counts"][0].size());
ASSERT_EQ("name", results_op.get()["facet_counts"][0]["field_name"]);
}
TEST_F(CollectionVectorTest, DistanceThresholdTest) {
nlohmann::json schema = R"({
"name": "test",