From 70738115466d9b49b929e6be12d4f62faa4afa82 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Mon, 9 Oct 2023 23:20:35 +0300 Subject: [PATCH 1/2] Add vector query param to set hybrid search alpha --- include/vector_query_ops.h | 1 + src/index.cpp | 4 +- src/vector_query_ops.cpp | 9 +++ test/collection_vector_search_test.cpp | 78 +++++++++++++++++++++++++- 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index 29bb36f6..b161bd3e 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -15,6 +15,7 @@ struct vector_query_t { uint32_t seq_id = 0; bool query_doc_given = false; + float alpha = 0.3; void _reset() { // used for testing only diff --git a/src/index.cpp b/src/index.cpp index 5059fd8d..8dc5ebb8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3178,8 +3178,8 @@ Option Index::search(std::vector& field_query_tokens, cons if(has_text_match) { // For hybrid search, we need to give weight to text match and vector search - constexpr float TEXT_MATCH_WEIGHT = 0.7; - constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT; + const float VECTOR_SEARCH_WEIGHT = vector_query.alpha; + const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT; VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); auto& field_vector_index = vector_index.at(vector_query.field_name); diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index 67443f2b..dba9d27d 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -156,6 +156,15 @@ Option VectorQueryOps::parse_vector_query_str(const std::string& vector_qu vector_query.distance_threshold = std::stof(param_kv[1]); } + + if(param_kv[0] == "alpha") { + if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 1.0) { + return Option(400, "Malformed vector query string: " + "`alpha` parameter must be a float between 0.0-1.0."); + } + + vector_query.alpha = 
std::stof(param_kv[1]); + } } return Option(true); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 55049641..cdb682f8 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -2497,4 +2497,80 @@ TEST_F(CollectionVectorTest, TestUnloadModelsCollectionHaveTwoEmbeddingField) { text_embedders = TextEmbedderManager::get_instance()._get_text_embedders(); ASSERT_EQ(0, text_embedders.size()); -} \ No newline at end of file +} + +TEST_F(CollectionVectorTest, TestHybridSearchAlphaParam) { + nlohmann::json schema = R"({ + "name": "test", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "embedding", + "type": "float[]", + "embed": { + "from": [ + "name" + ], + "model_config": { + "model_name": "ts/e5-small" + } + } + } + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema); + ASSERT_TRUE(collection_create_op.ok()); + + auto coll = collection_create_op.get(); + + auto add_op = coll->add(R"({ + "name": "soccer" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + add_op = coll->add(R"({ + "name": "basketball" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + add_op = coll->add(R"({ + "name": "volleyball" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + + // do hybrid search + auto hybrid_results = coll->search("sports", {"name", "embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(3, hybrid_results["hits"].size()); + + // check scores + ASSERT_FLOAT_EQ(0.3, hybrid_results["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ(0.15, hybrid_results["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ(0.10, hybrid_results["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get()); + + // do hybrid search with alpha = 0.5 + hybrid_results = 
coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "embedding:([], alpha:0.5)").get(); + ASSERT_EQ(3, hybrid_results["hits"].size()); + + // check scores + ASSERT_FLOAT_EQ(0.5, hybrid_results["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ(0.25, hybrid_results["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ(0.16666667, hybrid_results["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get()); +} \ No newline at end of file From 998b071956a412887035f76b75a0700149615b88 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Tue, 10 Oct 2023 11:20:15 +0300 Subject: [PATCH 2/2] Add test for invalid alpha params --- test/collection_vector_search_test.cpp | 77 +++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index cdb682f8..b22686d1 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -2573,4 +2573,79 @@ TEST_F(CollectionVectorTest, TestHybridSearchAlphaParam) { ASSERT_FLOAT_EQ(0.5, hybrid_results["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get()); ASSERT_FLOAT_EQ(0.25, hybrid_results["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get()); ASSERT_FLOAT_EQ(0.16666667, hybrid_results["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get()); -} \ No newline at end of file +} + +TEST_F(CollectionVectorTest, TestHybridSearchInvalidAlpha) { + nlohmann::json schema = R"({ + "name": "test", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "embedding", + "type": "float[]", + "embed": { + "from": [ + "name" + ], + "model_config": { + 
"model_name": "ts/e5-small" + } + } + } + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema); + ASSERT_TRUE(collection_create_op.ok()); + + auto coll = collection_create_op.get(); + + + // do hybrid search with alpha = 1.5 + auto hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "embedding:([], alpha:1.5)"); + + ASSERT_FALSE(hybrid_results.ok()); + ASSERT_EQ("Malformed vector query string: " + "`alpha` parameter must be a float between 0.0-1.0.", hybrid_results.error()); + + // do hybrid search with alpha = -0.5 + hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "embedding:([], alpha:-0.5)"); + + ASSERT_FALSE(hybrid_results.ok()); + ASSERT_EQ("Malformed vector query string: " + "`alpha` parameter must be a float between 0.0-1.0.", hybrid_results.error()); + + // do hybrid search with alpha as string + hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "embedding:([], alpha:\"0.5\")"); + + ASSERT_FALSE(hybrid_results.ok()); + ASSERT_EQ("Malformed vector query string: 
" + "`alpha` parameter must be a float between 0.0-1.0.", hybrid_results.error()); + +} \ No newline at end of file