diff --git a/include/field.h b/include/field.h index 4b948ae7..5ceaf909 100644 --- a/include/field.h +++ b/include/field.h @@ -50,7 +50,7 @@ namespace fields { } enum vector_distance_type_t { - squared_l2, + ip, cosine }; diff --git a/include/index.h b/include/index.h index 20ee3bd5..26ad1da1 100644 --- a/include/index.h +++ b/include/index.h @@ -585,7 +585,9 @@ public: static int get_bounded_typo_cost(const size_t max_cost, const size_t token_len, size_t min_len_1typo, size_t min_len_2typo); - static int64_t float_to_in64_t(float n); + static int64_t float_to_int64_t(float n); + + static float int64_t_to_float(int64_t n); uint64_t get_distinct_id(const std::vector& group_by_fields, const uint32_t seq_id) const; diff --git a/src/collection.cpp b/src/collection.cpp index f08a4d69..db0afa35 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1601,6 +1601,10 @@ Option Collection::search(const std::string & raw_query, wrapper_doc["geo_distance_meters"] = geo_distances; } + if(!vector_query.field_name.empty()) { + wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]); + } + hits_array.push_back(wrapper_doc); } diff --git a/src/index.cpp b/src/index.cpp index b740e7ae..355e2821 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -220,7 +220,7 @@ int64_t Index::get_points_from_doc(const nlohmann::json &document, const std::st return points; } -int64_t Index::float_to_in64_t(float f) { +int64_t Index::float_to_int64_t(float f) { // https://stackoverflow.com/questions/60530255/convert-float-to-int64-t-while-preserving-ordering int32_t i; memcpy(&i, &f, sizeof i); @@ -230,6 +230,18 @@ int64_t Index::float_to_in64_t(float f) { return i; } +float Index::int64_t_to_float(int64_t n) { + int32_t i = (int32_t) n; + + if(i < 0) { + i ^= INT32_MAX; + } + + float f; + memcpy(&f, &i, sizeof f); + return f; +} + void Index::compute_token_offsets_facets(index_record& record, const tsl::htrie_map& search_schema, const std::vector& local_token_separators, @@ -846,7 +858,7 @@ void Index::index_field_in_memory(const field& afield, std::vector iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { float fvalue = record.doc[afield.name].get(); - int64_t value = float_to_in64_t(fvalue); + int64_t value = float_to_int64_t(fvalue); num_tree->insert(value, seq_id); }); } else if(afield.type == field_types::BOOL) { @@ -924,7 +936,7 @@ void Index::index_field_in_memory(const field& afield, std::vector else if(afield.type == field_types::FLOAT_ARRAY) { const float fvalue = arr_value; - int64_t value = float_to_in64_t(fvalue); + int64_t value = float_to_int64_t(fvalue); num_tree->insert(value, seq_id); } @@ -977,7 +989,7 @@ void Index::index_field_in_memory(const field& afield, std::vector if(is_integer) { doc_to_score->emplace(seq_id, document[afield.name].get()); } else if(is_float) { - int64_t ifloat = float_to_in64_t(document[afield.name].get()); + int64_t ifloat = float_to_int64_t(document[afield.name].get()); doc_to_score->emplace(seq_id, ifloat); } else if(is_bool) { doc_to_score->emplace(seq_id, (int64_t) document[afield.name].get()); @@ -1646,11 +1658,11 @@ void Index::do_filtering(uint32_t*& filter_ids, uint32_t& filter_ids_length, for(size_t fi=0; fi < a_filter.values.size(); fi++) { const std::string & filter_value = a_filter.values[fi]; float value = (float) std::atof(filter_value.c_str()); - int64_t float_int64 = float_to_in64_t(value); + int64_t float_int64 = float_to_int64_t(value); if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi+1]; - int64_t range_end_value = float_to_in64_t((float) std::atof(next_filter_value.c_str())); + int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str())); num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len); fi++; } else { @@ -2561,7 +2573,7 @@ void Index::search(std::vector& field_query_tokens, const std::v } int64_t scores[3] = {0}; - scores[0] = -float_to_in64_t(dist_label.first); + scores[0] = -float_to_int64_t(dist_label.first); int64_t match_score_index = -1; KV kv(0, searched_queries.size(), 0, seq_id, distinct_id, match_score_index, scores); @@ -5012,7 +5024,7 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const document[field_name].get>(); for(float value: values) { num_tree_t* num_tree = numerical_index.at(field_name); - int64_t fintval = float_to_in64_t(value); + int64_t fintval = float_to_int64_t(value); num_tree->remove(fintval, seq_id); if(search_field.facet) { remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index cf8f8bbf..4d76c1cb 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -75,6 +75,10 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) { ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); ASSERT_STREQ("2", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_FLOAT_EQ(3.409385681152344e-05, results["hits"][0]["vector_distance"].get()); + ASSERT_FLOAT_EQ(0.04329806566238403, results["hits"][1]["vector_distance"].get()); + ASSERT_FLOAT_EQ(0.15141665935516357, results["hits"][2]["vector_distance"].get()); + // with filtering results = coll1->search("*", {}, "points:[0,1]", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), diff --git a/test/topster_test.cpp b/test/topster_test.cpp index 858448c3..7420c896 100644 --- a/test/topster_test.cpp +++ b/test/topster_test.cpp @@ -166,7 +166,7 @@ TEST(TopsterTest, MaxFloatValues) { for(int i = 0; i < 12; i++) { int64_t scores[3]; scores[0] = int64_t(data[i].match_score); - scores[1] = Index::float_to_in64_t(data[i].primary_attr); + scores[1] = Index::float_to_int64_t(data[i].primary_attr); scores[2] = data[i].secondary_attr; KV kv(data[i].field_id, data[i].query_index, data[i].token_bits, data[i].key, data[i].key, 0, scores);