mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 05:32:30 +08:00
Use cosine similarity as default vector distance.
This commit is contained in:
parent
b20c32046b
commit
720855f406
@ -46,8 +46,14 @@ namespace fields {
|
||||
static const std::string nested = "nested";
|
||||
static const std::string nested_array = "nested_array";
|
||||
static const std::string num_dim = "num_dim";
|
||||
static const std::string vec_dist = "vec_dist";
|
||||
}
|
||||
|
||||
// Distance metric used by a vector field's index. Stored per field as `vec_dist`
// and parsed from / serialized to the schema by enumerator name.
enum vector_distance_type_t {
|
||||
// Squared L2 distance: vectors are indexed and queried as given.
squared_l2,
|
||||
// Cosine similarity: vectors are normalized before being indexed and before
// being used as a query. This is the default for new fields.
cosine
|
||||
};
|
||||
|
||||
struct field {
|
||||
std::string name;
|
||||
std::string type;
|
||||
@ -66,6 +72,7 @@ struct field {
|
||||
int nested_array;
|
||||
|
||||
size_t num_dim;
|
||||
vector_distance_type_t vec_dist;
|
||||
|
||||
static constexpr int VAL_UNKNOWN = 2;
|
||||
|
||||
@ -73,9 +80,9 @@ struct field {
|
||||
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0) :
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist) {
|
||||
|
||||
if(sort != -1) {
|
||||
this->sort = bool(sort);
|
||||
|
@ -275,10 +275,12 @@ struct hnsw_index_t {
|
||||
hnswlib::L2Space* space;
|
||||
hnswlib::HierarchicalNSW<float, VectorFilterFunctor>* vecdex;
|
||||
size_t num_dim;
|
||||
vector_distance_type_t distance_type;
|
||||
|
||||
hnsw_index_t(size_t num_dim, size_t init_size): space(new hnswlib::L2Space(num_dim)),
|
||||
vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
|
||||
num_dim(num_dim) {
|
||||
hnsw_index_t(size_t num_dim, size_t init_size, vector_distance_type_t distance_type):
|
||||
space(new hnswlib::L2Space(num_dim)),
|
||||
vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
|
||||
num_dim(num_dim), distance_type(distance_type) {
|
||||
|
||||
}
|
||||
|
||||
@ -286,6 +288,18 @@ struct hnsw_index_t {
|
||||
delete vecdex;
|
||||
delete space;
|
||||
}
|
||||
|
||||
/**
 * Scales `src` to unit L2 length and stores the result in `norm_dest`.
 * Needed for cosine similarity: vectors are normalized before they are
 * indexed or used as a query.
 *
 * @param src        Input vector; left unmodified.
 * @param norm_dest  Output vector. Grown to `src.size()` when it is smaller, so a
 *                   short or empty destination can never be written out of bounds.
 *                   It is never shrunk: elements past `src.size()` keep their
 *                   previous values.
 *
 * An all-zero (or empty) `src` yields all zeros rather than NaN, because a tiny
 * epsilon is added to the norm before dividing.
 */
static void normalize_vector(const std::vector<float>& src, std::vector<float>& norm_dest) {
    // Guard against an undersized destination instead of writing past its end.
    if (norm_dest.size() < src.size()) {
        norm_dest.resize(src.size());
    }

    float sum_of_squares = 0.0f;
    for (const float component : src) {
        sum_of_squares += component * component;
    }

    // The epsilon keeps the reciprocal finite for zero-length vectors.
    const float inv_norm = 1.0f / (sqrtf(sum_of_squares) + 1e-30f);

    for (size_t i = 0; i < src.size(); i++) {
        norm_dest[i] = src[i] * inv_norm;
    }
}
|
||||
};
|
||||
|
||||
class Index {
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <store.h>
|
||||
#include "field.h"
|
||||
#include "magic_enum.hpp"
|
||||
|
||||
Option<bool> filter::parse_geopoint_filter_value(std::string& raw_value,
|
||||
const std::string& format_err_msg,
|
||||
@ -508,8 +509,11 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
|
||||
field_json[fields::infix] = false;
|
||||
}
|
||||
|
||||
auto DEFAULT_VEC_DIST_METRIC = magic_enum::enum_name(vector_distance_type_t::cosine);
|
||||
|
||||
if(field_json.count(fields::num_dim) == 0) {
|
||||
field_json[fields::num_dim] = 0;
|
||||
field_json[fields::vec_dist] = DEFAULT_VEC_DIST_METRIC;
|
||||
} else {
|
||||
if(!field_json[fields::num_dim].is_number_unsigned() || field_json[fields::num_dim] == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::num_dim + "` must be a positive integer.");
|
||||
@ -518,6 +522,19 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
|
||||
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::num_dim + "` is only allowed on a float array field.");
|
||||
}
|
||||
|
||||
if(field_json.count(fields::vec_dist) == 0) {
|
||||
field_json[fields::vec_dist] = DEFAULT_VEC_DIST_METRIC;
|
||||
} else {
|
||||
if(!field_json[fields::vec_dist].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::vec_dist + "` must be a string.");
|
||||
}
|
||||
|
||||
auto vec_dist_op = magic_enum::enum_cast<vector_distance_type_t>(field_json[fields::vec_dist].get<std::string>());
|
||||
if(!vec_dist_op.has_value()) {
|
||||
return Option<bool>(400, "Property `" + fields::vec_dist + "` is invalid.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json.count(fields::optional) == 0) {
|
||||
@ -542,11 +559,13 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
|
||||
field_json[fields::sort] = true;
|
||||
}
|
||||
|
||||
auto vec_dist = magic_enum::enum_cast<vector_distance_type_t>(field_json[fields::vec_dist].get<std::string>()).value();
|
||||
|
||||
the_fields.emplace_back(
|
||||
field(field_json[fields::name], field_json[fields::type], field_json[fields::facet],
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim])
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist)
|
||||
);
|
||||
|
||||
return Option<bool>(true);
|
||||
|
@ -118,7 +118,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
}
|
||||
|
||||
if(a_field.num_dim) {
|
||||
auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024);
|
||||
auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024, a_field.vec_dist);
|
||||
vector_index.emplace(a_field.name, hnsw_index);
|
||||
}
|
||||
}
|
||||
@ -931,13 +931,20 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
if(afield.type == field_types::FLOAT_ARRAY && afield.num_dim > 0) {
|
||||
const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
|
||||
auto vec_index = vector_index[afield.name]->vecdex;
|
||||
size_t curr_ele_count = vec_index->getCurrentElementCount();
|
||||
if(curr_ele_count == vec_index->getMaxElements()) {
|
||||
vec_index->resizeIndex(curr_ele_count * 1.3);
|
||||
}
|
||||
vector_index[afield.name]->vecdex->addPoint(float_vals.data(), (size_t)seq_id);
|
||||
|
||||
const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
|
||||
if(afield.vec_dist == cosine) {
|
||||
std::vector<float> normalized_vals(afield.num_dim);
|
||||
hnsw_index_t::normalize_vector(float_vals, normalized_vals);
|
||||
vector_index[afield.name]->vecdex->addPoint(normalized_vals.data(), (size_t)seq_id);
|
||||
} else {
|
||||
vector_index[afield.name]->vecdex->addPoint(float_vals.data(), (size_t)seq_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -2460,7 +2467,17 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
|
||||
auto k = per_page * page;
|
||||
VectorFilterFunctor filterFunctor(filter_ids, filter_ids_length);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
auto dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
|
||||
|
||||
std::vector<std::pair<float, size_t>> dist_labels;
|
||||
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_q(vector_query.values.size());
|
||||
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
|
||||
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, filterFunctor);
|
||||
} else {
|
||||
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> nearest_ids;
|
||||
|
||||
for(const auto& dist_label: dist_labels) {
|
||||
@ -5117,7 +5134,7 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
}
|
||||
|
||||
if(new_field.type == field_types::FLOAT_ARRAY && new_field.num_dim) {
|
||||
auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024);
|
||||
auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024, new_field.vec_dist);
|
||||
vector_index.emplace(new_field.name, hnsw_index);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user