mirror of
https://github.com/typesense/typesense.git
synced 2025-05-16 03:12:32 +08:00
Vector search basics.
This commit is contained in:
parent
a2a31e6c0b
commit
a98e5bacdd
@ -72,6 +72,7 @@ include(cmake/Jemalloc.cmake)
|
||||
include(cmake/s2.cmake)
|
||||
include(cmake/lrucache.cmake)
|
||||
include(cmake/kakasi.cmake)
|
||||
include(cmake/hnsw.cmake)
|
||||
|
||||
FIND_PACKAGE(OpenSSL 1.1.1 REQUIRED)
|
||||
FIND_PACKAGE(Snappy REQUIRED)
|
||||
@ -107,6 +108,7 @@ include_directories(${DEP_ROOT_DIR}/${S2_NAME}/src)
|
||||
include_directories(${DEP_ROOT_DIR}/${LRUCACHE_NAME}/include)
|
||||
include_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/build/include)
|
||||
include_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/data)
|
||||
include_directories(${DEP_ROOT_DIR}/${HNSW_NAME})
|
||||
|
||||
link_directories(/usr/local/lib)
|
||||
link_directories(${DEP_ROOT_DIR}/${GTEST_NAME}/googletest/build)
|
||||
|
15
cmake/hnsw.cmake
Normal file
15
cmake/hnsw.cmake
Normal file
@ -0,0 +1,15 @@
|
||||
# Download hnsw (header-only)
#
# Fetches the pinned hnswlib archive into DEP_ROOT_DIR and unpacks it there. Both steps are skipped
# when their output already exists, so a half-written archive must never be left behind.

set(HNSW_VERSION b87f6230dbe59e874b3099cfcab689b42e887a20)
set(HNSW_NAME hnswlib-${HNSW_VERSION})
set(HNSW_TAR_PATH ${DEP_ROOT_DIR}/${HNSW_NAME}.tar.gz)
set(HNSW_URL https://github.com/typesense/hnswlib/archive/${HNSW_VERSION}.tar.gz)

if(NOT EXISTS ${HNSW_TAR_PATH})
    message(STATUS "Downloading ${HNSW_URL}")
    file(DOWNLOAD ${HNSW_URL} ${HNSW_TAR_PATH} STATUS HNSW_DOWNLOAD_STATUS)

    # STATUS is a two element list: a numeric code (0 means success) followed by a message.
    list(GET HNSW_DOWNLOAD_STATUS 0 HNSW_DOWNLOAD_CODE)
    if(NOT HNSW_DOWNLOAD_CODE EQUAL 0)
        # Remove the broken file, otherwise the EXISTS check above would skip the download next time.
        file(REMOVE ${HNSW_TAR_PATH})
        message(FATAL_ERROR "Failed to download ${HNSW_URL}: ${HNSW_DOWNLOAD_STATUS}")
    endif()
endif()

if(NOT EXISTS ${DEP_ROOT_DIR}/${HNSW_NAME})
    message(STATUS "Extracting ${HNSW_NAME}...")
    execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzf ${HNSW_TAR_PATH} WORKING_DIRECTORY ${DEP_ROOT_DIR}/)
endif()
|
@ -412,7 +412,8 @@ public:
|
||||
const size_t max_extra_suffix = INT16_MAX,
|
||||
const size_t facet_query_num_typos = 2,
|
||||
const size_t filter_curated_hits_option = 2,
|
||||
const bool prioritize_token_position = false) const;
|
||||
const bool prioritize_token_position = false,
|
||||
const std::string& vector_query_str = "") const;
|
||||
|
||||
Option<bool> get_filter_ids(const std::string & simple_filter_query,
|
||||
std::vector<std::pair<size_t, uint32_t*>>& index_ids);
|
||||
|
@ -181,6 +181,8 @@ public:
|
||||
|
||||
static bool parse_sort_by_str(std::string sort_by_str, std::vector<sort_by>& sort_fields);
|
||||
|
||||
static bool parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query);
|
||||
|
||||
// symlinks
|
||||
Option<std::string> resolve_symlink(const std::string & symlink_name) const;
|
||||
|
||||
|
@ -45,6 +45,7 @@ namespace fields {
|
||||
static const std::string locale = "locale";
|
||||
static const std::string nested = "nested";
|
||||
static const std::string nested_array = "nested_array";
|
||||
static const std::string num_dim = "num_dim";
|
||||
}
|
||||
|
||||
struct field {
|
||||
@ -64,15 +65,17 @@ struct field {
|
||||
// third state is used to diff between array of object and array within object during write
|
||||
int nested_array;
|
||||
|
||||
size_t num_dim;
|
||||
|
||||
static constexpr int VAL_UNKNOWN = 2;
|
||||
|
||||
field() {}
|
||||
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0) :
|
||||
int nested_array = 0, size_t num_dim = 0) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim) {
|
||||
|
||||
if(sort != -1) {
|
||||
this->sort = bool(sort);
|
||||
@ -546,6 +549,21 @@ struct sort_by {
|
||||
}
|
||||
};
|
||||
|
||||
// Parsed form of a vector query: the field to search on, the query values and the optional parameters.
struct vector_query_t {
    std::string field_name;
    size_t k = 0;
    bool exact = false;
    std::vector<float> values;

    // Puts every member back to its initial state. Used for testing only.
    void _reset() {
        values.clear();
        exact = false;
        k = 0;
        field_name.clear();
    }
};
|
||||
|
||||
class GeoPoint {
|
||||
constexpr static const double EARTH_RADIUS = 3958.75;
|
||||
constexpr static const double METER_CONVERT = 1609.00;
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include "id_list.h"
|
||||
#include "synonym_index.h"
|
||||
#include "override.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
|
||||
static constexpr size_t ARRAY_FACET_DIM = 4;
|
||||
using facet_map_t = spp::sparse_hash_map<uint32_t, facet_hash_values_t>;
|
||||
@ -130,6 +131,8 @@ struct search_args {
|
||||
std::vector<std::vector<KV*>> raw_result_kvs;
|
||||
std::vector<std::vector<KV*>> override_result_kvs;
|
||||
|
||||
vector_query_t& vector_query;
|
||||
|
||||
search_args(std::vector<query_tokens_t> field_query_tokens, std::vector<search_field_t> search_fields,
|
||||
std::vector<filter> filters, std::vector<facet>& facets,
|
||||
std::vector<std::pair<uint32_t, uint32_t>>& included_ids, std::vector<uint32_t> excluded_ids,
|
||||
@ -142,7 +145,7 @@ struct search_args {
|
||||
size_t concurrency, size_t search_cutoff_ms,
|
||||
size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<enable_t>& infixes,
|
||||
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
|
||||
const bool filter_curated_hits, const enable_t split_join_tokens) :
|
||||
const bool filter_curated_hits, const enable_t split_join_tokens, vector_query_t& vector_query) :
|
||||
field_query_tokens(field_query_tokens),
|
||||
search_fields(search_fields), filters(filters), facets(facets),
|
||||
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
|
||||
@ -156,7 +159,7 @@ struct search_args {
|
||||
min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates),
|
||||
infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix),
|
||||
facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits),
|
||||
split_join_tokens(split_join_tokens) {
|
||||
split_join_tokens(split_join_tokens), vector_query(vector_query) {
|
||||
|
||||
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
|
||||
topster = new Topster(topster_size, group_limit);
|
||||
@ -229,6 +232,62 @@ struct index_record {
|
||||
}
|
||||
};
|
||||
|
||||
class VectorFilterFunctor: public hnswlib::FilterFunctor {
|
||||
const uint32_t* filter_ids = nullptr;
|
||||
const uint32_t filter_ids_length = 0;
|
||||
uint32 filter_ids_index = 0;
|
||||
|
||||
public:
|
||||
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
|
||||
filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
|
||||
|
||||
bool operator()(unsigned int id) {
|
||||
if(filter_ids_length != 0) {
|
||||
if(filter_ids_index >= filter_ids_length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns iterator to the first element that is >= to value or last if no such element is found.
|
||||
size_t found_index = std::lower_bound(filter_ids + filter_ids_index,
|
||||
filter_ids + filter_ids_length, id) - filter_ids;
|
||||
|
||||
if(found_index == filter_ids_length) {
|
||||
// all elements are lesser than lowest value (id), so we can stop looking
|
||||
filter_ids_index = found_index + 1;
|
||||
return false;
|
||||
} else {
|
||||
if(filter_ids[found_index] == id) {
|
||||
filter_ids_index = found_index + 1;
|
||||
return true;
|
||||
}
|
||||
|
||||
filter_ids_index = found_index;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct hnsw_index_t {
|
||||
hnswlib::L2Space* space;
|
||||
hnswlib::HierarchicalNSW<float, VectorFilterFunctor>* vecdex;
|
||||
size_t num_dim;
|
||||
|
||||
hnsw_index_t(size_t num_dim, size_t init_size): space(new hnswlib::L2Space(num_dim)),
|
||||
vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
|
||||
num_dim(num_dim) {
|
||||
|
||||
}
|
||||
|
||||
~hnsw_index_t() {
|
||||
delete vecdex;
|
||||
delete space;
|
||||
}
|
||||
};
|
||||
|
||||
class Index {
|
||||
private:
|
||||
mutable std::shared_mutex mutex;
|
||||
@ -268,6 +327,9 @@ private:
|
||||
// infix field => value
|
||||
spp::sparse_hash_map<std::string, array_mapped_infix_t> infix_index;
|
||||
|
||||
// vector field => vector index
|
||||
spp::sparse_hash_map<std::string, hnsw_index_t*> vector_index;
|
||||
|
||||
// this is used for wildcard queries
|
||||
id_list_t* seq_ids;
|
||||
|
||||
@ -569,7 +631,8 @@ public:
|
||||
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
|
||||
size_t max_candidates, const std::vector<enable_t>& infixes, const size_t max_extra_prefix,
|
||||
const size_t max_extra_suffix, const size_t facet_query_num_typos,
|
||||
const bool filter_curated_hits, enable_t split_join_tokens) const;
|
||||
const bool filter_curated_hits, enable_t split_join_tokens,
|
||||
const vector_query_t& vector_query) const;
|
||||
|
||||
void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name);
|
||||
|
||||
|
@ -823,7 +823,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
const size_t max_extra_suffix,
|
||||
const size_t facet_query_num_typos,
|
||||
const size_t filter_curated_hits_option,
|
||||
const bool prioritize_token_position) const {
|
||||
const bool prioritize_token_position,
|
||||
const std::string& vector_query_str) const {
|
||||
|
||||
std::shared_lock lock(mutex);
|
||||
|
||||
@ -871,6 +872,27 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
group_limit = 0;
|
||||
}
|
||||
|
||||
vector_query_t vector_query;
|
||||
if(!vector_query_str.empty()) {
|
||||
if(raw_query != "*") {
|
||||
return Option<nlohmann::json>(400, "Vector query is supported only on wildcard (q=*) searches.");
|
||||
}
|
||||
|
||||
if(!CollectionManager::parse_vector_query_str(vector_query_str, vector_query)) {
|
||||
return Option<nlohmann::json>(400, "The `vector_query` parameter is malformed.");
|
||||
}
|
||||
|
||||
auto vector_field_it = search_schema.find(vector_query.field_name);
|
||||
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
|
||||
return Option<nlohmann::json>(400, "Field `" + vector_query.field_name + "` does not have a vector query index.");
|
||||
}
|
||||
|
||||
if(vector_field_it.value().num_dim != vector_query.values.size()) {
|
||||
return Option<nlohmann::json>(400, "Query field `" + vector_query.field_name + "` must have " +
|
||||
std::to_string(vector_field_it.value().num_dim) + " dimensions.");
|
||||
}
|
||||
}
|
||||
|
||||
// validate search fields
|
||||
std::vector<std::string> processed_search_fields;
|
||||
|
||||
@ -1204,7 +1226,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
search_stop_millis,
|
||||
min_len_1typo, min_len_2typo, max_candidates, infixes,
|
||||
max_extra_prefix, max_extra_suffix, facet_query_num_typos,
|
||||
filter_curated_hits, split_join_tokens);
|
||||
filter_curated_hits, split_join_tokens, vector_query);
|
||||
|
||||
index->run_search(search_params);
|
||||
|
||||
|
@ -608,6 +608,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
const char *FACET_QUERY_NUM_TYPOS = "facet_query_num_typos";
|
||||
const char *MAX_FACET_VALUES = "max_facet_values";
|
||||
|
||||
const char *VECTOR_QUERY = "vector_query";
|
||||
|
||||
const char *GROUP_BY = "group_by";
|
||||
const char *GROUP_LIMIT = "group_limit";
|
||||
|
||||
@ -692,6 +694,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
size_t page = 1;
|
||||
token_ordering token_order = NOT_SET;
|
||||
|
||||
std::string vector_query;
|
||||
|
||||
std::vector<std::string> include_fields_vec;
|
||||
std::vector<std::string> exclude_fields_vec;
|
||||
spp::sparse_hash_set<std::string> include_fields;
|
||||
@ -747,6 +751,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
|
||||
std::unordered_map<std::string, std::string*> str_values = {
|
||||
{FILTER, &simple_filter_query},
|
||||
{VECTOR_QUERY, &vector_query},
|
||||
{FACET_QUERY, &simple_facet_query},
|
||||
{HIGHLIGHT_FIELDS, &highlight_fields},
|
||||
{HIGHLIGHT_FULL_FIELDS, &highlight_full_fields},
|
||||
@ -925,7 +930,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
max_extra_suffix,
|
||||
facet_query_num_typos,
|
||||
filter_curated_hits_option,
|
||||
prioritize_token_position
|
||||
prioritize_token_position,
|
||||
vector_query
|
||||
);
|
||||
|
||||
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
@ -1335,3 +1341,110 @@ Option<Collection*> CollectionManager::clone_collection(const string& existing_n
|
||||
|
||||
return Option<Collection*>(new_coll);
|
||||
}
|
||||
|
||||
bool CollectionManager::parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query) {
|
||||
// FORMAT:
|
||||
// field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)
|
||||
size_t i = 0;
|
||||
while(i < vector_query_str.size()) {
|
||||
if(vector_query_str[i] != ':') {
|
||||
vector_query.field_name += vector_query_str[i];
|
||||
i++;
|
||||
} else {
|
||||
if(vector_query_str[i] != ':') {
|
||||
// missing ":"
|
||||
return false;
|
||||
}
|
||||
|
||||
// field name is done
|
||||
i++;
|
||||
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != '(') {
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != '(') {
|
||||
// missing "("
|
||||
return false;
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != '[') {
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != '[') {
|
||||
// missing opening "["
|
||||
return false;
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
std::string values_str;
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != ']') {
|
||||
values_str += vector_query_str[i];
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != ']') {
|
||||
// missing closing "]"
|
||||
return false;
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
std::vector<std::string> svalues;
|
||||
StringUtils::split(values_str, svalues, ",");
|
||||
|
||||
for(auto& svalue: svalues) {
|
||||
if(!StringUtils::is_float(svalue)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector_query.values.push_back(std::stof(svalue));
|
||||
}
|
||||
|
||||
if(i == vector_query_str.size()-1) {
|
||||
// missing params
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string param_str = vector_query_str.substr(i, (vector_query_str.size() - i));
|
||||
std::vector<std::string> param_kvs;
|
||||
StringUtils::split(param_str, param_kvs, ",");
|
||||
|
||||
for(auto& param_kv_str: param_kvs) {
|
||||
if(param_kv_str.back() == ')') {
|
||||
param_kv_str.pop_back();
|
||||
}
|
||||
|
||||
std::vector<std::string> param_kv;
|
||||
StringUtils::split(param_kv_str, param_kv, ":");
|
||||
if(param_kv.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if(param_kv[0] == "k") {
|
||||
if(!StringUtils::is_uint32_t(param_kv[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector_query.k = std::stoul(param_kv[1]);
|
||||
}
|
||||
|
||||
if(param_kv[0] == "exact") {
|
||||
if(!StringUtils::is_bool(param_kv[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector_query.exact = (param_kv[1] == "true") ;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -508,6 +508,18 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
|
||||
field_json[fields::infix] = false;
|
||||
}
|
||||
|
||||
if(field_json.count(fields::num_dim) == 0) {
|
||||
field_json[fields::num_dim] = 0;
|
||||
} else {
|
||||
if(!field_json[fields::num_dim].is_number_unsigned() || field_json[fields::num_dim] == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::num_dim + "` must be a positive integer.");
|
||||
}
|
||||
|
||||
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::num_dim + "` is only allowed on a float array field.");
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json.count(fields::optional) == 0) {
|
||||
// dynamic type fields are always optional
|
||||
bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]);
|
||||
@ -531,10 +543,10 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
|
||||
}
|
||||
|
||||
the_fields.emplace_back(
|
||||
field(field_json[fields::name], field_json[fields::type], field_json[fields::facet],
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array])
|
||||
field(field_json[fields::name], field_json[fields::type], field_json[fields::facet],
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim])
|
||||
);
|
||||
|
||||
return Option<bool>(true);
|
||||
|
@ -116,6 +116,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
|
||||
infix_index.emplace(a_field.name, infix_sets);
|
||||
}
|
||||
|
||||
if(a_field.num_dim) {
|
||||
auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024);
|
||||
vector_index.emplace(a_field.name, hnsw_index);
|
||||
}
|
||||
}
|
||||
|
||||
num_documents = 0;
|
||||
@ -898,7 +903,7 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
} else if(afield.is_array()) {
|
||||
// all other numerical arrays
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, &vector_index=vector_index]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) {
|
||||
const auto& arr_value = record.doc[afield.name][arr_i];
|
||||
@ -924,6 +929,11 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
num_tree->insert(int64_t(value), seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
if(afield.type == field_types::FLOAT_ARRAY && afield.num_dim > 0) {
|
||||
const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
|
||||
vector_index[afield.name]->vecdex->addPoint(float_vals.data(), (size_t)seq_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -1944,7 +1954,8 @@ void Index::run_search(search_args* search_params) {
|
||||
search_params->max_extra_suffix,
|
||||
search_params->facet_query_num_typos,
|
||||
search_params->filter_curated_hits,
|
||||
search_params->split_join_tokens);
|
||||
search_params->split_join_tokens,
|
||||
search_params->vector_query);
|
||||
}
|
||||
|
||||
void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
|
||||
@ -2373,7 +2384,8 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
|
||||
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
|
||||
size_t max_candidates, const std::vector<enable_t>& infixes, const size_t max_extra_prefix,
|
||||
const size_t max_extra_suffix, const size_t facet_query_num_typos,
|
||||
const bool filter_curated_hits, const enable_t split_join_tokens) const {
|
||||
const bool filter_curated_hits, const enable_t split_join_tokens,
|
||||
const vector_query_t& vector_query) const {
|
||||
|
||||
// process the filters
|
||||
|
||||
@ -2437,13 +2449,48 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
|
||||
|
||||
curate_filtered_ids(filters, curated_ids, excluded_result_ids,
|
||||
excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted);
|
||||
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
|
||||
|
||||
search_wildcard(filters, included_ids_map, sort_fields_std, topster,
|
||||
curated_topster, groups_processed, searched_queries, group_limit, group_by_fields,
|
||||
curated_ids, curated_ids_sorted,
|
||||
excluded_result_ids, excluded_result_ids_size, field_id, field,
|
||||
all_result_ids, all_result_ids_len, filter_ids, filter_ids_length, concurrency,
|
||||
sort_order, field_values, geopoint_indices);
|
||||
if(!vector_query.field_name.empty()) {
|
||||
auto k = per_page * page;
|
||||
VectorFilterFunctor filterFunctor(filter_ids, filter_ids_length);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
auto dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
|
||||
std::vector<uint32_t> nearest_ids;
|
||||
|
||||
for(const auto& dist_label: dist_labels) {
|
||||
uint32 seq_id = dist_label.second;
|
||||
uint64_t distinct_id = seq_id;
|
||||
if(group_limit != 0) {
|
||||
distinct_id = get_distinct_id(group_by_fields, seq_id);
|
||||
groups_processed.emplace(distinct_id);
|
||||
}
|
||||
|
||||
int64_t scores[3] = {0};
|
||||
scores[0] = -float_to_in64_t(dist_label.first);
|
||||
int64_t match_score_index = -1;
|
||||
|
||||
KV kv(0, searched_queries.size(), 0, seq_id, distinct_id, match_score_index, scores);
|
||||
topster->add(&kv);
|
||||
nearest_ids.push_back(seq_id);
|
||||
}
|
||||
|
||||
if(!nearest_ids.empty()) {
|
||||
uint32_t* new_all_result_ids = nullptr;
|
||||
all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, nearest_ids.data(),
|
||||
nearest_ids.size(), &new_all_result_ids);
|
||||
delete [] all_result_ids;
|
||||
all_result_ids = new_all_result_ids;
|
||||
}
|
||||
|
||||
} else {
|
||||
search_wildcard(filters, included_ids_map, sort_fields_std, topster,
|
||||
curated_topster, groups_processed, searched_queries, group_limit, group_by_fields,
|
||||
curated_ids, curated_ids_sorted,
|
||||
excluded_result_ids, excluded_result_ids_size, field_id, field,
|
||||
all_result_ids, all_result_ids_len, filter_ids, filter_ids_length, concurrency,
|
||||
sort_order, field_values, geopoint_indices);
|
||||
}
|
||||
} else {
|
||||
// Non-wildcard
|
||||
// In multi-field searches, a record can be matched across different fields, so we use this for aggregation
|
||||
@ -4114,8 +4161,6 @@ void Index::search_wildcard(const std::vector<filter>& filters,
|
||||
std::chrono::high_resolution_clock::now() - beginF).count();
|
||||
LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/
|
||||
|
||||
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
|
||||
|
||||
uint32_t* new_all_result_ids = nullptr;
|
||||
all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids,
|
||||
filter_ids_length, &new_all_result_ids);
|
||||
@ -4858,6 +4903,10 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
int64_t fintval = float_to_in64_t(value);
|
||||
num_tree->remove(fintval, seq_id);
|
||||
}
|
||||
|
||||
if(search_field.num_dim) {
|
||||
vector_index[search_field.name]->vecdex->markDelete(seq_id);
|
||||
}
|
||||
} else if(search_field.is_bool()) {
|
||||
|
||||
const std::vector<bool>& values = search_field.is_single_bool() ?
|
||||
@ -5045,6 +5094,11 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
|
||||
infix_index.emplace(new_field.name, infix_sets);
|
||||
}
|
||||
|
||||
if(new_field.type == field_types::FLOAT_ARRAY && new_field.num_dim) {
|
||||
auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024);
|
||||
vector_index.emplace(new_field.name, hnsw_index);
|
||||
}
|
||||
}
|
||||
|
||||
for(const auto & del_field: del_fields) {
|
||||
@ -5113,6 +5167,12 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
|
||||
infix_index.erase(del_field.name);
|
||||
}
|
||||
|
||||
if(del_field.num_dim) {
|
||||
auto hnsw_index = vector_index[del_field.name];
|
||||
delete hnsw_index;
|
||||
vector_index.erase(del_field.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -778,6 +778,43 @@ TEST_F(CollectionManagerTest, ParseSortByClause) {
|
||||
ASSERT_FALSE(sort_by_parsed);
|
||||
}
|
||||
|
||||
// Exercises CollectionManager::parse_vector_query_str with well-formed and malformed vector queries.
TEST_F(CollectionManagerTest, ParseVectorQueryString) {
    vector_query_t vector_query;

    // all parts present
    bool parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)", vector_query);
    ASSERT_TRUE(parsed);
    ASSERT_EQ("vec", vector_query.field_name);
    ASSERT_EQ(10, vector_query.k);
    ASSERT_FALSE(vector_query.exact);

    std::vector<float> fvs = {0.34, 0.66, 0.12, 0.68};
    ASSERT_EQ(fvs.size(), vector_query.values.size());
    for(size_t i = 0; i < fvs.size(); i++) {
        // the expected values come from double literals, so compare with a floating point tolerance
        ASSERT_FLOAT_EQ(fvs[i], vector_query.values[i]);
    }

    // `exact` param left out
    vector_query._reset();
    parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query);
    ASSERT_TRUE(parsed);

    // opening "(" missing
    vector_query._reset();
    parsed = CollectionManager::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], exact: false, k: 10)", vector_query);
    ASSERT_FALSE(parsed);

    // closing ")" missing: still accepted
    vector_query._reset();
    parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], exact: false, k: 10", vector_query);
    ASSERT_TRUE(parsed);

    // values not wrapped in "[" and "]"
    vector_query._reset();
    parsed = CollectionManager::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, exact: false, k: 10)", vector_query);
    ASSERT_FALSE(parsed);

    // dangling "," without a param
    vector_query._reset();
    parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query);
    ASSERT_FALSE(parsed);

    // ":" after the field name missing
    vector_query._reset();
    parsed = CollectionManager::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query);
    ASSERT_FALSE(parsed);
}
|
||||
|
||||
TEST_F(CollectionManagerTest, Presets) {
|
||||
// try getting on a blank slate
|
||||
auto presets = collectionManager.get_presets();
|
||||
|
156
test/collection_vector_search_test.cpp
Normal file
156
test/collection_vector_search_test.cpp
Normal file
@ -0,0 +1,156 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <collection_manager.h>
|
||||
#include "collection.h"
|
||||
|
||||
class CollectionVectorTest : public ::testing::Test {
|
||||
protected:
|
||||
Store *store;
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
std::atomic<bool> quit = false;
|
||||
|
||||
std::vector<std::string> query_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
void setupCollection() {
|
||||
std::string state_dir_path = "/tmp/typesense_test/collection_vector_search";
|
||||
LOG(INFO) << "Truncating and creating: " << state_dir_path;
|
||||
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
|
||||
|
||||
store = new Store(state_dir_path);
|
||||
collectionManager.init(store, 1.0, "auth_key", quit);
|
||||
collectionManager.load(8, 1000);
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
setupCollection();
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
collectionManager.dispose();
|
||||
delete store;
|
||||
}
|
||||
};
|
||||
|
||||
// Indexes three documents with a 4-dimensional vector field, then checks the ordering of vector query
// results (with and without a filter) and the validation errors of the query and of the schema.
TEST_F(CollectionVectorTest, BasicVectorQuerying) {
    nlohmann::json schema = R"({
        "name": "coll1",
        "fields": [
            {"name": "title", "type": "string"},
            {"name": "points", "type": "int32"},
            {"name": "vec", "type": "float[]", "num_dim": 4}
        ]
    })"_json;

    Collection* coll1 = collectionManager.create_collection(schema).get();

    const std::vector<std::vector<float>> vectors = {
        {0.851758, 0.909671, 0.823431, 0.372063},
        {0.97826, 0.933157, 0.39557, 0.306488},
        {0.230606, 0.634397, 0.514009, 0.399594}
    };

    for (size_t doc_i = 0; doc_i < vectors.size(); doc_i++) {
        const std::string id_str = std::to_string(doc_i);

        nlohmann::json doc;
        doc["id"] = id_str;
        doc["title"] = id_str + " title";
        doc["points"] = doc_i;
        doc["vec"] = vectors[doc_i];
        ASSERT_TRUE(coll1->add(doc.dump()).ok());
    }

    // All searches in this test differ only in the query, the searched fields, the filter and the vector
    // query, so the remaining arguments are spelled out once.
    auto run_search = [&](const std::string& q, const std::vector<std::string>& search_fields,
                          const std::string& filter_str, const std::string& vector_query_str) {
        return coll1->search(q, search_fields, filter_str, {}, {}, {0}, 10, 1, FREQUENCY, {true},
                             Index::DROP_TOKENS_THRESHOLD,
                             spp::sparse_hash_set<std::string>(),
                             spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                             "", 10, {}, {}, {}, 0,
                             "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                             4, {off}, 32767, 32767, 2,
                             false, true, vector_query_str);
    };

    // plain vector query
    auto results = run_search("*", {}, "", "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();

    ASSERT_EQ(3, results["found"].get<size_t>());
    ASSERT_EQ(3, results["hits"].size());

    const std::vector<std::string> unfiltered_ids = {"1", "0", "2"};
    for(size_t hit_i = 0; hit_i < unfiltered_ids.size(); hit_i++) {
        ASSERT_STREQ(unfiltered_ids[hit_i].c_str(),
                     results["hits"][hit_i]["document"]["id"].get<std::string>().c_str());
    }

    // with filtering
    results = run_search("*", {}, "points:[0,1]", "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();

    ASSERT_EQ(2, results["found"].get<size_t>());
    ASSERT_EQ(2, results["hits"].size());

    const std::vector<std::string> filtered_ids = {"1", "0"};
    for(size_t hit_i = 0; hit_i < filtered_ids.size(); hit_i++) {
        ASSERT_STREQ(filtered_ids[hit_i].c_str(),
                     results["hits"][hit_i]["document"]["id"].get<std::string>().c_str());
    }

    // validate wrong dimensions in query
    auto res_op = run_search("*", {}, "", "vec:([0.96826, 0.94, 0.39557])");

    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Query field `vec` must have 4 dimensions.", res_op.error());

    // validate bad vector query field name
    res_op = run_search("*", {}, "", "zec:([0.96826, 0.94, 0.39557, 0.4542])");

    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Field `zec` does not have a vector query index.", res_op.error());

    // only supported with wildcard queries
    res_op = run_search("title", {"title"}, "", "zec:([0.96826, 0.94, 0.39557, 0.4542])");

    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error());

    // support num_dim on only float array fields
    schema = R"({
        "name": "coll2",
        "fields": [
            {"name": "title", "type": "string"},
            {"name": "vec", "type": "float", "num_dim": 4}
        ]
    })"_json;

    auto coll_op = collectionManager.create_collection(schema);
    ASSERT_FALSE(coll_op.ok());
    ASSERT_EQ("Property `num_dim` is only allowed on a float array field.", coll_op.error());

    // bad value for num_dim
    schema = R"({
        "name": "coll2",
        "fields": [
            {"name": "title", "type": "string"},
            {"name": "vec", "type": "float", "num_dim": -4}
        ]
    })"_json;

    coll_op = collectionManager.create_collection(schema);
    ASSERT_FALSE(coll_op.ok());
    ASSERT_EQ("Property `num_dim` must be a positive integer.", coll_op.error());

    collectionManager.drop_collection("coll1");
}
|
Loading…
x
Reference in New Issue
Block a user