diff --git a/src/index.cpp b/src/index.cpp
index 6cf559f0..b2c838ed 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -3204,19 +3204,18 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     auto result = result_it->second;
                     // old_score + (1 / rank_of_document) * WEIGHT)
                     result->vector_distance = vec_result.second;
-                    result->scores[result->match_score_index] = float_to_int64_t(
+                    int64_t match_score = float_to_int64_t(
                             (int64_t_to_float(result->scores[result->match_score_index])) +
                             ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
+                    int64_t match_score_index = -1;
+                    int64_t scores[3] = {0};
+
+                    compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
 
-                    for(size_t i = 0;i < 3; i++) {
-                        if(field_values[i] == &vector_distance_sentinel_value) {
-                            result->scores[i] = float_to_int64_t(vec_result.second);
-                        }
-
-                        if(sort_order[i] == -1) {
-                            result->scores[i] = -result->scores[i];
-                        }
+                    for(int i = 0; i < 3; i++) {
+                        result->scores[i] = scores[i];
                     }
+                    result->match_score_index = match_score_index;
 
                 } else {
                     // Result has been found only in vector search: we have to add it to both KV and result_ids
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index a1863b20..40b324f4 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -5215,4 +5215,71 @@ TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) {
 
     ASSERT_EQ(res["response"]["error"], "Malformed response from OpenAI API.");
     ASSERT_EQ(res["request"]["body"], req_body);
+}
+
+// Checks hit ordering for a search over `name` + `embedding`: first with a geopoint sort_by, then with no sort_by.
+TEST_F(CollectionTest, HybridSearchSortByGeopoint) {
+    nlohmann::json schema = R"({
+        "name": "objects",
+        "fields": [
+            {"name": "name", "type": "string"},
+            {"name": "location", "type": "geopoint"},
+            {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
+        ]
+    })"_json;
+
+    // `embedding` is declared with embed.from = ["name"] and model ts/e5-small.
+    auto op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(op.ok());
+
+    auto coll = op.get();
+
+    nlohmann::json doc;
+    doc["name"] = "butter";
+    doc["location"] = {80.0, 150.0};
+
+    auto add_op = coll->add(doc.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    doc["name"] = "butterball";
+    doc["location"] = {40.0, 100.0};
+
+    add_op = coll->add(doc.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    doc["name"] = "butterfly";
+    doc["location"] = {130.0, 200.0};
+
+    add_op = coll->add(doc.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    // Search 1: results sorted by `location(10.0, 10.0)` ascending.
+    spp::sparse_hash_set<std::string> dummy_include_exclude;
+
+    std::vector<sort_by> sort_by_list = {{"location(10.0, 10.0)", "asc"}};
+
+    auto search_res_op = coll->search("butter", {"name", "embedding"}, "", {}, sort_by_list, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10);
+
+    ASSERT_TRUE(search_res_op.ok());
+
+    auto search_res = search_res_op.get();
+
+    ASSERT_EQ("butterfly", search_res["hits"][0]["document"]["name"].get<std::string>());
+    ASSERT_EQ("butterball", search_res["hits"][1]["document"]["name"].get<std::string>());
+    ASSERT_EQ("butter", search_res["hits"][2]["document"]["name"].get<std::string>());
+
+    // Search 2: same query with an empty sort list and {false} in place of {true}.
+    search_res_op = coll->search("butter", {"name", "embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10);
+
+    ASSERT_TRUE(search_res_op.ok());
+
+    search_res = search_res_op.get();
+
+    // Expected hit order for the search without sort_by.
+    ASSERT_EQ("butter", search_res["hits"][0]["document"]["name"].get<std::string>());
+    ASSERT_EQ("butterball", search_res["hits"][1]["document"]["name"].get<std::string>());
+    ASSERT_EQ("butterfly", search_res["hits"][2]["document"]["name"].get<std::string>());
 }
\ No newline at end of file