mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 06:40:30 +08:00
Merge pull request #1292 from ozanarmagan/v0.25-join
Add vector query param to set hybrid search alpha
This commit is contained in:
commit
d64f7f2cfe
@ -15,6 +15,7 @@ struct vector_query_t {
|
||||
|
||||
uint32_t seq_id = 0;
|
||||
bool query_doc_given = false;
|
||||
float alpha = 0.3;
|
||||
|
||||
void _reset() {
|
||||
// used for testing only
|
||||
|
@ -3188,8 +3188,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
|
||||
if(has_text_match) {
|
||||
// For hybrid search, we need to give weight to text match and vector search
|
||||
constexpr float TEXT_MATCH_WEIGHT = 0.7;
|
||||
constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT;
|
||||
const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
|
||||
const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
|
||||
|
||||
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
|
@ -156,6 +156,15 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_qu
|
||||
|
||||
vector_query.distance_threshold = std::stof(param_kv[1]);
|
||||
}
|
||||
|
||||
if(param_kv[0] == "alpha") {
|
||||
if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 1.0) {
|
||||
return Option<bool>(400, "Malformed vector query string: "
|
||||
"`alpha` parameter must be a float between 0.0-1.0.");
|
||||
}
|
||||
|
||||
vector_query.alpha = std::stof(param_kv[1]);
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
|
@ -2497,4 +2497,155 @@ TEST_F(CollectionVectorTest, TestUnloadModelsCollectionHaveTwoEmbeddingField) {
|
||||
|
||||
text_embedders = TextEmbedderManager::get_instance()._get_text_embedders();
|
||||
ASSERT_EQ(0, text_embedders.size());
|
||||
}
|
||||
|
||||
// Checks the rank fusion scores reported for a hybrid (keyword + embedding) search,
// first without a vector query string and then with `alpha:0.5` passed explicitly.
TEST_F(CollectionVectorTest, TestHybridSearchAlphaParam) {
    // One plain string field plus an embedding field auto-generated from it.
    nlohmann::json coll_schema = R"({
        "name": "test",
        "fields": [
            {
                "name": "name",
                "type": "string"
            },
            {
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": [
                        "name"
                    ],
                    "model_config": {
                        "model_name": "ts/e5-small"
                    }
                }
            }
        ]
    })"_json;

    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto create_op = collectionManager.create_collection(coll_schema);
    ASSERT_TRUE(create_op.ok());

    auto collection = create_op.get();

    // Index three single-field documents, one per sport.
    const char* const sport_names[] = {"soccer", "basketball", "volleyball"};
    for(const char* sport_name : sport_names) {
        nlohmann::json doc;
        doc["name"] = sport_name;

        auto add_op = collection->add(doc.dump());
        ASSERT_TRUE(add_op.ok());
    }

    // Hybrid search without a vector query string.
    auto results = collection->search("sports", {"name", "embedding"},
                                      "", {}, {}, {2}, 10,
                                      1, FREQUENCY, {true},
                                      0, spp::sparse_hash_set<std::string>()).get();

    const auto& default_hits = results["hits"];
    ASSERT_EQ(3, default_hits.size());

    // Expected fusion scores for the three hits, in rank order.
    ASSERT_FLOAT_EQ(0.3, default_hits[0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
    ASSERT_FLOAT_EQ(0.15, default_hits[1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
    ASSERT_FLOAT_EQ(0.10, default_hits[2]["hybrid_search_info"]["rank_fusion_score"].get<float>());

    // Same query again, this time passing alpha = 0.5 through the vector query string.
    results = collection->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true},
                                 Index::DROP_TOKENS_THRESHOLD,
                                 spp::sparse_hash_set<std::string>(),
                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                 "", 10, {}, {}, {}, 0,
                                 "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                 fallback,
                                 4, {off}, 32767, 32767, 2,
                                 false, true, "embedding:([], alpha:0.5)").get();

    const auto& alpha_hits = results["hits"];
    ASSERT_EQ(3, alpha_hits.size());

    // Expected fusion scores with the explicit alpha.
    ASSERT_FLOAT_EQ(0.5, alpha_hits[0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
    ASSERT_FLOAT_EQ(0.25, alpha_hits[1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
    ASSERT_FLOAT_EQ(0.16666667, alpha_hits[2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
}
|
||||
|
||||
// Checks that a hybrid search is rejected with the "Malformed vector query string"
// error when `alpha` is above 1.0, below 0.0, or given as a quoted string.
TEST_F(CollectionVectorTest, TestHybridSearchInvalidAlpha) {
    // One plain string field plus an embedding field auto-generated from it.
    nlohmann::json coll_schema = R"({
        "name": "test",
        "fields": [
            {
                "name": "name",
                "type": "string"
            },
            {
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": [
                        "name"
                    ],
                    "model_config": {
                        "model_name": "ts/e5-small"
                    }
                }
            }
        ]
    })"_json;

    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto create_op = collectionManager.create_collection(coll_schema);
    ASSERT_TRUE(create_op.ok());

    auto collection = create_op.get();

    // Each entry is a vector query whose alpha must be refused, checked in this order:
    // alpha = 1.5, alpha = -0.5, alpha passed as a quoted string.
    const char* const rejected_vector_queries[] = {
        "embedding:([], alpha:1.5)",
        "embedding:([], alpha:-0.5)",
        "embedding:([], alpha:\"0.5\")"
    };

    for(const char* vector_query_str : rejected_vector_queries) {
        auto search_op = collection->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true},
                                            Index::DROP_TOKENS_THRESHOLD,
                                            spp::sparse_hash_set<std::string>(),
                                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                            "", 10, {}, {}, {}, 0,
                                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                            fallback,
                                            4, {off}, 32767, 32767, 2,
                                            false, true, vector_query_str);

        ASSERT_FALSE(search_op.ok());
        ASSERT_EQ("Malformed vector query string: "
                  "`alpha` parameter must be a float between 0.0-1.0.", search_op.error());
    }
}
|
Loading…
x
Reference in New Issue
Block a user