Fix sorting by geopoint when using hybrid search

This commit is contained in:
ozanarmagan 2023-07-25 14:16:53 +03:00
parent c299172ccc
commit 4e76c4780c
2 changed files with 75 additions and 9 deletions

View File

@ -3204,19 +3204,18 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto result = result_it->second;
// old_score + (1 / rank_of_document) * WEIGHT)
result->vector_distance = vec_result.second;
result->scores[result->match_score_index] = float_to_int64_t(
int64_t match_score = float_to_int64_t(
(int64_t_to_float(result->scores[result->match_score_index])) +
((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
int64_t match_score_index = -1;
int64_t scores[3] = {0};
compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
for(size_t i = 0;i < 3; i++) {
if(field_values[i] == &vector_distance_sentinel_value) {
result->scores[i] = float_to_int64_t(vec_result.second);
}
if(sort_order[i] == -1) {
result->scores[i] = -result->scores[i];
}
for(int i = 0; i < 3; i++) {
result->scores[i] = scores[i];
}
result->match_score_index = match_score_index;
} else {
// Result has been found only in vector search: we have to add it to both KV and result_ids

View File

@ -5215,4 +5215,71 @@ TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) {
ASSERT_EQ(res["response"]["error"], "Malformed response from OpenAI API.");
ASSERT_EQ(res["request"]["body"], req_body);
}
// Hybrid search ("name" + auto-embedded "embedding" field) must honour an
// explicit geopoint sort clause. The same query is then re-run without any
// sort clause and is expected to yield a different hit order.
TEST_F(CollectionTest, HybridSearchSortByGeopoint) {
    // Schema: a plain string, a geopoint, and an embedding generated from "name".
    nlohmann::json collection_schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "location", "type": "geopoint"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;

    auto create_op = collectionManager.create_collection(collection_schema);
    ASSERT_TRUE(create_op.ok());
    auto collection = create_op.get();

    // Documents are inserted in exactly this order.
    // NOTE(review): {130.0, 200.0} looks outside the usual latitude/longitude
    // ranges -- confirm this is intentional for the expected ordering below.
    struct fixture_doc {
        const char* name;
        double first_coord;
        double second_coord;
    };
    const fixture_doc fixtures[] = {
        {"butter", 80.0, 150.0},
        {"butterball", 40.0, 100.0},
        {"butterfly", 130.0, 200.0},
    };

    for(const auto& fixture : fixtures) {
        nlohmann::json record;
        record["name"] = fixture.name;
        record["location"] = {fixture.first_coord, fixture.second_coord};

        auto insert_op = collection->add(record.dump());
        ASSERT_TRUE(insert_op.ok());
    }

    spp::sparse_hash_set<std::string> no_field_filter;

    // 1) Hybrid query sorted ascending by location(10.0, 10.0).
    std::vector<sort_by> geo_sort = {{"location(10.0, 10.0)", "asc"}};
    auto sorted_op = collection->search("butter", {"name", "embedding"}, "", {}, geo_sort, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, no_field_filter, no_field_filter, 10);
    ASSERT_TRUE(sorted_op.ok());
    auto sorted_res = sorted_op.get();

    const std::vector<std::string> geo_sorted_names = {"butterfly", "butterball", "butter"};
    for(size_t hit = 0; hit < geo_sorted_names.size(); hit++) {
        ASSERT_EQ(geo_sorted_names[hit], sorted_res["hits"][hit]["document"]["name"].get<std::string>());
    }

    // 2) Same hybrid query with an empty sort list (and {false} as the tenth argument).
    auto unsorted_op = collection->search("butter", {"name", "embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, no_field_filter, no_field_filter, 10);
    ASSERT_TRUE(unsorted_op.ok());
    auto unsorted_res = unsorted_op.get();

    const std::vector<std::string> default_order_names = {"butter", "butterball", "butterfly"};
    for(size_t hit = 0; hit < default_order_names.size(); hit++) {
        ASSERT_EQ(default_order_names[hit], unsorted_res["hits"][hit]["document"]["name"].get<std::string>());
    }
}