Merge pull request #1292 from ozanarmagan/v0.25-join

Add vector query param to set hybrid search alpha
This commit is contained in:
Kishore Nallan 2023-10-10 15:30:29 +05:30 committed by GitHub
commit d64f7f2cfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 163 additions and 2 deletions

View File

@ -15,6 +15,7 @@ struct vector_query_t {
uint32_t seq_id = 0;
bool query_doc_given = false;
float alpha = 0.3;
void _reset() {
// used for testing only

View File

@ -3188,8 +3188,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
if(has_text_match) {
// For hybrid search, we need to give weight to text match and vector search
constexpr float TEXT_MATCH_WEIGHT = 0.7;
constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT;
const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
auto& field_vector_index = vector_index.at(vector_query.field_name);

View File

@ -156,6 +156,15 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_qu
vector_query.distance_threshold = std::stof(param_kv[1]);
}
if(param_kv[0] == "alpha") {
if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 1.0) {
return Option<bool>(400, "Malformed vector query string: "
"`alpha` parameter must be a float between 0.0-1.0.");
}
vector_query.alpha = std::stof(param_kv[1]);
}
}
return Option<bool>(true);

View File

@ -2497,4 +2497,155 @@ TEST_F(CollectionVectorTest, TestUnloadModelsCollectionHaveTwoEmbeddingField) {
text_embedders = TextEmbedderManager::get_instance()._get_text_embedders();
ASSERT_EQ(0, text_embedders.size());
}
// Exercises the `alpha` parameter of the vector query in hybrid search:
// first without any vector query (expected scores match the 0.3 default alpha),
// then with an explicit `alpha:0.5`. In both cases the expected rank fusion scores
// equal alpha divided by 1, 2 and 3 for the first, second and third hit.
TEST_F(CollectionVectorTest, TestHybridSearchAlphaParam) {
    nlohmann::json coll_schema = R"({
"name": "test",
"fields": [
{
"name": "name",
"type": "string"
},
{
"name": "embedding",
"type": "float[]",
"embed": {
"from": [
"name"
],
"model_config": {
"model_name": "ts/e5-small"
}
}
}
]
})"_json;

    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto create_op = collectionManager.create_collection(coll_schema);
    ASSERT_TRUE(create_op.ok());
    auto coll = create_op.get();

    // Index three documents, in this order, each carrying only a `name` field.
    const std::vector<std::string> sport_names = {"soccer", "basketball", "volleyball"};
    for(const std::string& sport_name : sport_names) {
        nlohmann::json doc;
        doc["name"] = sport_name;

        auto add_op = coll->add(doc.dump());
        ASSERT_TRUE(add_op.ok());
    }

    // Reads the rank fusion score reported for the hit at position `hit_idx`.
    auto fusion_score_of = [](nlohmann::json& results, size_t hit_idx) {
        return results["hits"][hit_idx]["hybrid_search_info"]["rank_fusion_score"].get<float>();
    };

    // Hybrid search without a vector query string.
    auto default_alpha_results = coll->search("sports", {"name", "embedding"},
                                              "", {}, {}, {2}, 10,
                                              1, FREQUENCY, {true},
                                              0, spp::sparse_hash_set<std::string>()).get();

    ASSERT_EQ(3, default_alpha_results["hits"].size());

    ASSERT_FLOAT_EQ(0.3, fusion_score_of(default_alpha_results, 0));
    ASSERT_FLOAT_EQ(0.15, fusion_score_of(default_alpha_results, 1));
    ASSERT_FLOAT_EQ(0.10, fusion_score_of(default_alpha_results, 2));

    // Hybrid search with alpha explicitly set to 0.5 through the vector query string.
    const std::string half_alpha_query = "embedding:([], alpha:0.5)";
    auto custom_alpha_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1,
                                             FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                             spp::sparse_hash_set<std::string>(),
                                             spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                             "", 10, {}, {}, {}, 0,
                                             "<mark>", "</mark>", {}, 1000, true, false, true, "", false,
                                             6000 * 1000, 4, 7,
                                             fallback,
                                             4, {off}, 32767, 32767, 2,
                                             false, true, half_alpha_query).get();

    ASSERT_EQ(3, custom_alpha_results["hits"].size());

    ASSERT_FLOAT_EQ(0.5, fusion_score_of(custom_alpha_results, 0));
    ASSERT_FLOAT_EQ(0.25, fusion_score_of(custom_alpha_results, 1));
    ASSERT_FLOAT_EQ(0.16666667, fusion_score_of(custom_alpha_results, 2));
}
// Checks that hybrid search rejects an unusable `alpha` value in the vector query:
// a value above 1.0, a negative value, and a quoted (string) value must each fail
// with the same "Malformed vector query string" error.
TEST_F(CollectionVectorTest, TestHybridSearchInvalidAlpha) {
    nlohmann::json coll_schema = R"({
"name": "test",
"fields": [
{
"name": "name",
"type": "string"
},
{
"name": "embedding",
"type": "float[]",
"embed": {
"from": [
"name"
],
"model_config": {
"model_name": "ts/e5-small"
}
}
}
]
})"_json;

    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto create_op = collectionManager.create_collection(coll_schema);
    ASSERT_TRUE(create_op.ok());
    auto coll = create_op.get();

    // Every rejected query is expected to report exactly this message.
    const std::string expected_error = "Malformed vector query string: "
                                       "`alpha` parameter must be a float between 0.0-1.0.";

    // Runs the same hybrid search each time; only the vector query string varies.
    auto search_with_vector_query = [&](const std::string& vector_query_str) {
        return coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1,
                            FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                            spp::sparse_hash_set<std::string>(),
                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                            "", 10, {}, {}, {}, 0,
                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false,
                            6000 * 1000, 4, 7,
                            fallback,
                            4, {off}, 32767, 32767, 2,
                            false, true, vector_query_str);
    };

    // alpha = 1.5
    auto above_range_op = search_with_vector_query("embedding:([], alpha:1.5)");
    ASSERT_FALSE(above_range_op.ok());
    ASSERT_EQ(expected_error, above_range_op.error());

    // alpha = -0.5
    auto below_range_op = search_with_vector_query("embedding:([], alpha:-0.5)");
    ASSERT_FALSE(below_range_op.ok());
    ASSERT_EQ(expected_error, below_range_op.error());

    // alpha given as a quoted string
    auto quoted_alpha_op = search_with_vector_query("embedding:([], alpha:\"0.5\")");
    ASSERT_FALSE(quoted_alpha_op.ok());
    ASSERT_EQ(expected_error, quoted_alpha_op.error());
}