diff --git a/WORKSPACE b/WORKSPACE index 098015ff..ce8521d2 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -160,8 +160,8 @@ new_git_repository( new_git_repository( name = "hnsw", build_file = "//bazel:hnsw.BUILD", - commit = "21de18ffabea1a9d1e8b16b49afc6045d7707e4c", - remote = "https://github.com/typesense/hnswlib.git", + commit = "359b2ba87358224963986f709e593d799064ace6", + remote = "https://github.com/nmslib/hnswlib.git", ) http_archive( diff --git a/include/index.h b/include/index.h index 332e9fa2..53c54995 100644 --- a/include/index.h +++ b/include/index.h @@ -232,7 +232,7 @@ struct index_record { } }; -class VectorFilterFunctor: public hnswlib::FilterFunctor { +class VectorFilterFunctor: public hnswlib::BaseFilterFunctor { const uint32_t* filter_ids = nullptr; const uint32_t filter_ids_length = 0; @@ -240,7 +240,7 @@ public: explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) : filter_ids(filter_ids), filter_ids_length(filter_ids_length) {} - bool operator()(unsigned int id) { + bool operator()(hnswlib::labeltype id) override { if(filter_ids_length == 0) { return true; } @@ -251,13 +251,13 @@ public: struct hnsw_index_t { hnswlib::InnerProductSpace* space; - hnswlib::HierarchicalNSW* vecdex; + hnswlib::HierarchicalNSW* vecdex; size_t num_dim; vector_distance_type_t distance_type; hnsw_index_t(size_t num_dim, size_t init_size, vector_distance_type_t distance_type): space(new hnswlib::InnerProductSpace(num_dim)), - vecdex(new hnswlib::HierarchicalNSW(space, init_size, 16, 200, 100, true)), + vecdex(new hnswlib::HierarchicalNSW(space, init_size, 16, 200, 100, true)), num_dim(num_dim), distance_type(distance_type) { } diff --git a/src/index.cpp b/src/index.cpp index a24e6cb7..e6305315 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -850,9 +850,9 @@ void Index::index_field_in_memory(const field& afield, std::vector if(afield.vec_dist == cosine) { std::vector normalized_vals(afield.num_dim); hnsw_index_t::normalize_vector(float_vals, normalized_vals); - vec_index->insertPoint(normalized_vals.data(), (size_t)record.seq_id); + vec_index->addPoint(normalized_vals.data(), (size_t)record.seq_id, true); } else { - vec_index->insertPoint(float_vals.data(), (size_t)record.seq_id); + vec_index->addPoint(float_vals.data(), (size_t)record.seq_id, true); } } catch(const std::exception &e) { record.index_failure(400, e.what()); @@ -2866,9 +2866,9 @@ Option Index::search(std::vector& field_query_tokens, cons if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); hnsw_index_t::normalize_vector(vector_query.values, normalized_q); - dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, filterFunctor); + dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, &filterFunctor); } else { - dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor); + dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } } @@ -3113,9 +3113,9 @@ Option Index::search(std::vector& field_query_tokens, cons if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); hnsw_index_t::normalize_vector(vector_query.values, normalized_q); - dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, filterFunctor); + dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, &filterFunctor); } else { - dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor); + dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } std::vector> vec_results;