Use cosine similarity as default vector distance.

This commit is contained in:
Kishore Nallan 2022-09-02 17:48:02 +05:30
parent b20c32046b
commit 720855f406
4 changed files with 68 additions and 11 deletions

View File

@ -46,8 +46,14 @@ namespace fields {
static const std::string nested = "nested";
static const std::string nested_array = "nested_array";
static const std::string num_dim = "num_dim";
static const std::string vec_dist = "vec_dist";
}
// Distance metric applied to a vector field.
// The enumerator names are converted to and from the `vec_dist` schema
// property via magic_enum, so renaming one changes the accepted input string.
enum vector_distance_type_t {
    squared_l2 = 0,  // vectors are indexed and queried exactly as supplied
    cosine = 1       // vectors are scaled to unit length before indexing and querying
};
struct field {
std::string name;
std::string type;
@ -66,6 +72,7 @@ struct field {
int nested_array;
size_t num_dim;
vector_distance_type_t vec_dist;
static constexpr int VAL_UNKNOWN = 2;
@ -73,9 +80,9 @@ struct field {
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
int nested_array = 0, size_t num_dim = 0) :
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine) :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
nested(nested), nested_array(nested_array), num_dim(num_dim) {
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist) {
if(sort != -1) {
this->sort = bool(sort);

View File

@ -275,10 +275,12 @@ struct hnsw_index_t {
hnswlib::L2Space* space;
hnswlib::HierarchicalNSW<float, VectorFilterFunctor>* vecdex;
size_t num_dim;
vector_distance_type_t distance_type;
hnsw_index_t(size_t num_dim, size_t init_size): space(new hnswlib::L2Space(num_dim)),
vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
num_dim(num_dim) {
hnsw_index_t(size_t num_dim, size_t init_size, vector_distance_type_t distance_type):
space(new hnswlib::L2Space(num_dim)),
vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
num_dim(num_dim), distance_type(distance_type) {
}
@ -286,6 +288,18 @@ struct hnsw_index_t {
delete vecdex;
delete space;
}
// Scales `src` to unit length and writes the result into `norm_dest`.
// Needed for cosine similarity.
//
// src:       input vector; left untouched.
// norm_dest: receives the normalized components in positions [0, src.size()).
//            It is grown when it has fewer elements than `src` so the writes
//            below can never run past its end; it is never shrunk, so any
//            extra trailing elements a caller pre-allocated stay as they were.
//
// An all-zero `src` produces all zeros rather than NaN/inf.
static void normalize_vector(const std::vector<float>& src, std::vector<float>& norm_dest) {
    if (norm_dest.size() < src.size()) {
        norm_dest.resize(src.size());
    }

    float sum_of_squares = 0.0f;
    for (const float component : src) {
        sum_of_squares += component * component;
    }

    // The tiny epsilon keeps the reciprocal finite when the magnitude is zero.
    const float inv_norm = 1.0f / (sqrtf(sum_of_squares) + 1e-30f);

    for (size_t i = 0; i < src.size(); i++) {
        norm_dest[i] = src[i] * inv_norm;
    }
}
};
class Index {

View File

@ -1,5 +1,6 @@
#include <store.h>
#include "field.h"
#include "magic_enum.hpp"
Option<bool> filter::parse_geopoint_filter_value(std::string& raw_value,
const std::string& format_err_msg,
@ -508,8 +509,11 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
field_json[fields::infix] = false;
}
auto DEFAULT_VEC_DIST_METRIC = magic_enum::enum_name(vector_distance_type_t::cosine);
if(field_json.count(fields::num_dim) == 0) {
field_json[fields::num_dim] = 0;
field_json[fields::vec_dist] = DEFAULT_VEC_DIST_METRIC;
} else {
if(!field_json[fields::num_dim].is_number_unsigned() || field_json[fields::num_dim] == 0) {
return Option<bool>(400, "Property `" + fields::num_dim + "` must be a positive integer.");
@ -518,6 +522,19 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
return Option<bool>(400, "Property `" + fields::num_dim + "` is only allowed on a float array field.");
}
if(field_json.count(fields::vec_dist) == 0) {
field_json[fields::vec_dist] = DEFAULT_VEC_DIST_METRIC;
} else {
if(!field_json[fields::vec_dist].is_string()) {
return Option<bool>(400, "Property `" + fields::vec_dist + "` must be a string.");
}
auto vec_dist_op = magic_enum::enum_cast<vector_distance_type_t>(field_json[fields::vec_dist].get<std::string>());
if(!vec_dist_op.has_value()) {
return Option<bool>(400, "Property `" + fields::vec_dist + "` is invalid.");
}
}
}
if(field_json.count(fields::optional) == 0) {
@ -542,11 +559,13 @@ Option<bool> field::json_field_to_field(nlohmann::json& field_json, std::vector<
field_json[fields::sort] = true;
}
auto vec_dist = magic_enum::enum_cast<vector_distance_type_t>(field_json[fields::vec_dist].get<std::string>()).value();
the_fields.emplace_back(
field(field_json[fields::name], field_json[fields::type], field_json[fields::facet],
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
field_json[fields::nested_array], field_json[fields::num_dim])
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist)
);
return Option<bool>(true);

View File

@ -118,7 +118,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
}
if(a_field.num_dim) {
auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024);
auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024, a_field.vec_dist);
vector_index.emplace(a_field.name, hnsw_index);
}
}
@ -931,13 +931,20 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
}
if(afield.type == field_types::FLOAT_ARRAY && afield.num_dim > 0) {
const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
auto vec_index = vector_index[afield.name]->vecdex;
size_t curr_ele_count = vec_index->getCurrentElementCount();
if(curr_ele_count == vec_index->getMaxElements()) {
vec_index->resizeIndex(curr_ele_count * 1.3);
}
vector_index[afield.name]->vecdex->addPoint(float_vals.data(), (size_t)seq_id);
const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
if(afield.vec_dist == cosine) {
std::vector<float> normalized_vals(afield.num_dim);
hnsw_index_t::normalize_vector(float_vals, normalized_vals);
vector_index[afield.name]->vecdex->addPoint(normalized_vals.data(), (size_t)seq_id);
} else {
vector_index[afield.name]->vecdex->addPoint(float_vals.data(), (size_t)seq_id);
}
}
});
}
@ -2460,7 +2467,17 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
auto k = per_page * page;
VectorFilterFunctor filterFunctor(filter_ids, filter_ids_length);
auto& field_vector_index = vector_index.at(vector_query.field_name);
auto dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
std::vector<std::pair<float, size_t>> dist_labels;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, filterFunctor);
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
}
std::vector<uint32_t> nearest_ids;
for(const auto& dist_label: dist_labels) {
@ -5117,7 +5134,7 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
}
if(new_field.type == field_types::FLOAT_ARRAY && new_field.num_dim) {
auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024);
auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024, new_field.vec_dist);
vector_index.emplace(new_field.name, hnsw_index);
}
}