mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 13:12:22 +08:00
Return vector distance in response.
This commit is contained in:
parent
31559f15b2
commit
c7f879bf30
@ -50,7 +50,7 @@ namespace fields {
|
||||
}
|
||||
|
||||
enum vector_distance_type_t {
|
||||
squared_l2,
|
||||
ip,
|
||||
cosine
|
||||
};
|
||||
|
||||
|
@ -585,7 +585,9 @@ public:
|
||||
static int get_bounded_typo_cost(const size_t max_cost, const size_t token_len,
|
||||
size_t min_len_1typo, size_t min_len_2typo);
|
||||
|
||||
static int64_t float_to_in64_t(float n);
|
||||
static int64_t float_to_int64_t(float n);
|
||||
|
||||
static float int64_t_to_float(int64_t n);
|
||||
|
||||
uint64_t get_distinct_id(const std::vector<std::string>& group_by_fields, const uint32_t seq_id) const;
|
||||
|
||||
|
@ -1601,6 +1601,10 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
wrapper_doc["geo_distance_meters"] = geo_distances;
|
||||
}
|
||||
|
||||
if(!vector_query.field_name.empty()) {
|
||||
wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]);
|
||||
}
|
||||
|
||||
hits_array.push_back(wrapper_doc);
|
||||
}
|
||||
|
||||
|
@ -220,7 +220,7 @@ int64_t Index::get_points_from_doc(const nlohmann::json &document, const std::st
|
||||
return points;
|
||||
}
|
||||
|
||||
int64_t Index::float_to_in64_t(float f) {
|
||||
int64_t Index::float_to_int64_t(float f) {
|
||||
// https://stackoverflow.com/questions/60530255/convert-float-to-int64-t-while-preserving-ordering
|
||||
int32_t i;
|
||||
memcpy(&i, &f, sizeof i);
|
||||
@ -230,6 +230,18 @@ int64_t Index::float_to_in64_t(float f) {
|
||||
return i;
|
||||
}
|
||||
|
||||
float Index::int64_t_to_float(int64_t n) {
|
||||
int32_t i = (int32_t) n;
|
||||
|
||||
if(i < 0) {
|
||||
i ^= INT32_MAX;
|
||||
}
|
||||
|
||||
float f;
|
||||
memcpy(&f, &i, sizeof f);
|
||||
return f;
|
||||
}
|
||||
|
||||
void Index::compute_token_offsets_facets(index_record& record,
|
||||
const tsl::htrie_map<char, field>& search_schema,
|
||||
const std::vector<char>& local_token_separators,
|
||||
@ -846,7 +858,7 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
float fvalue = record.doc[afield.name].get<float>();
|
||||
int64_t value = float_to_in64_t(fvalue);
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
num_tree->insert(value, seq_id);
|
||||
});
|
||||
} else if(afield.type == field_types::BOOL) {
|
||||
@ -924,7 +936,7 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
else if(afield.type == field_types::FLOAT_ARRAY) {
|
||||
const float fvalue = arr_value;
|
||||
int64_t value = float_to_in64_t(fvalue);
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
num_tree->insert(value, seq_id);
|
||||
}
|
||||
|
||||
@ -977,7 +989,7 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
if(is_integer) {
|
||||
doc_to_score->emplace(seq_id, document[afield.name].get<int64_t>());
|
||||
} else if(is_float) {
|
||||
int64_t ifloat = float_to_in64_t(document[afield.name].get<float>());
|
||||
int64_t ifloat = float_to_int64_t(document[afield.name].get<float>());
|
||||
doc_to_score->emplace(seq_id, ifloat);
|
||||
} else if(is_bool) {
|
||||
doc_to_score->emplace(seq_id, (int64_t) document[afield.name].get<bool>());
|
||||
@ -1646,11 +1658,11 @@ void Index::do_filtering(uint32_t*& filter_ids, uint32_t& filter_ids_length,
|
||||
for(size_t fi=0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string & filter_value = a_filter.values[fi];
|
||||
float value = (float) std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = float_to_in64_t(value);
|
||||
int64_t float_int64 = float_to_int64_t(value);
|
||||
|
||||
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = float_to_in64_t((float) std::atof(next_filter_value.c_str()));
|
||||
int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
|
||||
fi++;
|
||||
} else {
|
||||
@ -2561,7 +2573,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
|
||||
}
|
||||
|
||||
int64_t scores[3] = {0};
|
||||
scores[0] = -float_to_in64_t(dist_label.first);
|
||||
scores[0] = -float_to_int64_t(dist_label.first);
|
||||
int64_t match_score_index = -1;
|
||||
|
||||
KV kv(0, searched_queries.size(), 0, seq_id, distinct_id, match_score_index, scores);
|
||||
@ -5012,7 +5024,7 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
document[field_name].get<std::vector<float>>();
|
||||
for(float value: values) {
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
int64_t fintval = float_to_in64_t(value);
|
||||
int64_t fintval = float_to_int64_t(value);
|
||||
num_tree->remove(fintval, seq_id);
|
||||
if(search_field.facet) {
|
||||
remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id);
|
||||
|
@ -75,6 +75,10 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) {
|
||||
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("2", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
ASSERT_FLOAT_EQ(3.409385681152344e-05, results["hits"][0]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(0.04329806566238403, results["hits"][1]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(0.15141665935516357, results["hits"][2]["vector_distance"].get<float>());
|
||||
|
||||
// with filtering
|
||||
results = coll1->search("*", {}, "points:[0,1]", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
|
@ -166,7 +166,7 @@ TEST(TopsterTest, MaxFloatValues) {
|
||||
for(int i = 0; i < 12; i++) {
|
||||
int64_t scores[3];
|
||||
scores[0] = int64_t(data[i].match_score);
|
||||
scores[1] = Index::float_to_in64_t(data[i].primary_attr);
|
||||
scores[1] = Index::float_to_int64_t(data[i].primary_attr);
|
||||
scores[2] = data[i].secondary_attr;
|
||||
|
||||
KV kv(data[i].field_id, data[i].query_index, data[i].token_bits, data[i].key, data[i].key, 0, scores);
|
||||
|
Loading…
x
Reference in New Issue
Block a user