diff --git a/include/collection.h b/include/collection.h
index 75bf51ae..aab3ce31 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -209,6 +209,8 @@ private:
     static Option<bool> parse_pinned_hits(const std::string& pinned_hits_str,
                                           std::map<size_t, std::vector<std::string>>& pinned_hits);
 
+    static Option<drop_tokens_param_t> parse_drop_tokens_mode(const std::string& drop_tokens_mode);
+
     Index* init_index();
 
     static std::vector<char*> to_char_array(const std::vector<std::string>& strs);
@@ -502,7 +504,7 @@ public:
                                   const std::string& stopwords_set="",
                                   const std::vector<std::string>& facet_return_parent = {},
                                   const std::vector<std::string>& ref_include_fields_vec = {},
-                                  const drop_tokens_mode_t drop_tokens_mode = right_to_left,
+                                  const std::string& drop_tokens_mode = "right_to_left",
                                   const bool prioritize_num_matching_fields = true,
                                   const bool group_missing_values = true) const;
 
diff --git a/include/index.h b/include/index.h
index 79c49c8b..2dbc88d6 100644
--- a/include/index.h
+++ b/include/index.h
@@ -101,6 +101,18 @@ enum text_match_type_t {
 enum drop_tokens_mode_t {
     left_to_right,
     right_to_left,
+    both_sides,
+};
+
+struct drop_tokens_param_t {
+    drop_tokens_mode_t mode = right_to_left;
+    size_t token_limit = 1000;
+
+    drop_tokens_param_t() {
+
+    }
+
+    drop_tokens_param_t(drop_tokens_mode_t mode, size_t token_limit) : mode(mode), token_limit(token_limit) {}
 };
 
 struct search_args {
@@ -153,7 +165,7 @@ struct search_args {
     vector_query_t& vector_query;
     size_t facet_sample_percent;
     size_t facet_sample_threshold;
-    drop_tokens_mode_t drop_tokens_mode;
+    drop_tokens_param_t drop_tokens_mode;
 
     search_args(std::vector<query_tokens_t> field_query_tokens, std::vector<search_field_t> search_fields,
                 const text_match_type_t match_type,
@@ -170,7 +182,7 @@ struct search_args {
                 size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<enable_t>& infixes,
                 const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
                 const bool filter_curated_hits, const enable_t split_join_tokens, vector_query_t& vector_query,
-                size_t facet_sample_percent, size_t facet_sample_threshold, drop_tokens_mode_t drop_tokens_mode) :
+                size_t facet_sample_percent, size_t facet_sample_threshold, drop_tokens_param_t drop_tokens_mode) :
             field_query_tokens(field_query_tokens), search_fields(search_fields), match_type(match_type),
             filter_tree_root(filter_tree_root), facets(facets),
             included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@@ -672,8 +684,10 @@ public:
                        const size_t max_extra_suffix, const size_t facet_query_num_typos,
                        const bool filter_curated_hits, enable_t split_join_tokens,
                        const vector_query_t& vector_query,
                        size_t facet_sample_percent, size_t facet_sample_threshold,
-                       const std::string& collection_name, facet_index_type_t facet_index_type = DETECT,
-                       const drop_tokens_mode_t drop_tokens_mode = right_to_left) const;
+                       const std::string& collection_name,
+                       const drop_tokens_param_t drop_tokens_mode,
+                       facet_index_type_t facet_index_type = DETECT
+                       ) const;
 
     void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name,
                       const bool is_update);
 
diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h
index 29bb36f6..b161bd3e 100644
--- a/include/vector_query_ops.h
+++ b/include/vector_query_ops.h
@@ -15,6 +15,7 @@ struct vector_query_t {
     uint32_t seq_id = 0;
     bool query_doc_given = false;
+    float alpha = 0.3;
 
     void _reset() {
         // used for testing only
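Taken together, the header changes replace the bare `drop_tokens_mode_t` value with a `drop_tokens_param_t` pair, and expose a vector-search weight (`alpha`) on `vector_query_t`. A minimal sketch of how the new struct composes, illustrative only (the type and member names come from the hunks above):

```cpp
// Illustrative: pairing a drop direction with a token-count cutoff.
drop_tokens_param_t defaults;               // right_to_left, token_limit = 1000
drop_tokens_param_t both(both_sides, 3);    // drop from both ends, but only for
                                            // queries with at most 3 tokens
```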
diff --git a/src/collection.cpp b/src/collection.cpp
index 89f07fe2..bdb56b5c 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -1419,7 +1419,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                                           const std::string& stopwords_set,
                                           const std::vector<std::string>& facet_return_parent,
                                           const std::vector<std::string>& ref_include_fields_vec,
-                                          const drop_tokens_mode_t drop_tokens_mode,
+                                          const std::string& drop_tokens_mode,
                                           const bool prioritize_num_matching_fields,
                                           const bool group_missing_values) const {
 
@@ -1779,6 +1779,13 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
         }
     }
 
+    Option<drop_tokens_param_t> drop_tokens_param_op = parse_drop_tokens_mode(drop_tokens_mode);
+    if(!drop_tokens_param_op.ok()) {
+        return Option<nlohmann::json>(drop_tokens_param_op.code(), drop_tokens_param_op.error());
+    }
+
+    auto drop_tokens_param = drop_tokens_param_op.get();
+
     std::vector<std::vector<KV*>> raw_result_kvs;
     std::vector<std::vector<KV*>> override_result_kvs;
 
@@ -1936,7 +1943,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                                                   min_len_1typo, min_len_2typo, max_candidates, infixes,
                                                   max_extra_prefix, max_extra_suffix, facet_query_num_typos,
                                                   filter_curated_hits, split_join_tokens, vector_query,
-                                                  facet_sample_percent, facet_sample_threshold, drop_tokens_mode);
+                                                  facet_sample_percent, facet_sample_threshold, drop_tokens_param);
 
     std::unique_ptr<search_args> search_params_guard(search_params);
 
@@ -4071,6 +4078,35 @@ Option<bool> Collection::parse_pinned_hits(const std::string& pinned_hits_str,
     return Option<bool>(true);
 }
 
+Option<drop_tokens_param_t> Collection::parse_drop_tokens_mode(const std::string& drop_tokens_mode) {
+    drop_tokens_mode_t drop_tokens_mode_val = left_to_right;
+    size_t drop_tokens_token_limit = 1000;
+    auto drop_tokens_mode_op = magic_enum::enum_cast<drop_tokens_mode_t>(drop_tokens_mode);
+    if(drop_tokens_mode_op.has_value()) {
+        drop_tokens_mode_val = drop_tokens_mode_op.value();
+    } else {
+        std::vector<std::string> drop_token_parts;
+        StringUtils::split(drop_tokens_mode, drop_token_parts, ":");
+        if(drop_token_parts.size() == 2) {
+            if(!StringUtils::is_uint32_t(drop_token_parts[1])) {
+                return Option<drop_tokens_param_t>(400, "Invalid format for drop tokens mode.");
+            }
+
+            drop_tokens_mode_op = magic_enum::enum_cast<drop_tokens_mode_t>(drop_token_parts[0]);
+            if(drop_tokens_mode_op.has_value()) {
+                drop_tokens_mode_val = drop_tokens_mode_op.value();
+            }
+
+            drop_tokens_token_limit = std::stoul(drop_token_parts[1]);
+
+        } else {
+            return Option<drop_tokens_param_t>(400, "Invalid format for drop tokens mode.");
+        }
+    }
+
+    return Option<drop_tokens_param_t>(drop_tokens_param_t(drop_tokens_mode_val, drop_tokens_token_limit));
+}
+
 Option<bool> Collection::add_synonym(const nlohmann::json& syn_json, bool write_to_store) {
     std::shared_lock lock(mutex);
     synonym_t synonym;
 
diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp
index 107bc1a7..6129c63c 100644
--- a/src/collection_manager.cpp
+++ b/src/collection_manager.cpp
@@ -1415,12 +1415,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& req_params,
                                                     Index::NUM_CANDIDATES_DEFAULT_MIN);
     }
 
-    auto drop_tokens_mode_op = magic_enum::enum_cast<drop_tokens_mode_t>(drop_tokens_mode_str);
-    drop_tokens_mode_t drop_tokens_mode;
-    if(drop_tokens_mode_op.has_value()) {
-        drop_tokens_mode = drop_tokens_mode_op.value();
-    }
-
     Option<nlohmann::json> result_op = collection->search(raw_query, search_fields, filter_query, facet_fields,
                                                           sort_fields, num_typos, per_page,
@@ -1470,7 +1464,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& req_params,
                                                           stopwords_set,
                                                           facet_return_parent,
                                                           ref_include_fields_vec,
-                                                          drop_tokens_mode,
+                                                          drop_tokens_mode_str,
                                                           prioritize_num_matching_fields,
                                                           group_missing_values);
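`parse_drop_tokens_mode` accepts either a bare mode name or a `mode:token_limit` pair, rejecting anything else with a 400. A hedged summary of the mappings implied by the parser above (the method is private to `Collection`, so the direct call below is illustrative only):

```cpp
// Illustrative inputs and outcomes, per the parser above:
//   "right_to_left"  -> {right_to_left, 1000}   (bare enum name, default limit)
//   "both_sides:3"   -> {both_sides, 3}         (mode:token_limit pair)
//   "both_sides:x"   -> 400 "Invalid format for drop tokens mode." (limit not a uint32)
//   "all_sides"      -> 400 (neither an enum value nor a mode:limit pair)
auto op = Collection::parse_drop_tokens_mode("both_sides:3");
if(op.ok()) {
    drop_tokens_param_t param = op.get();    // param.mode == both_sides, param.token_limit == 3
}
```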
diff --git a/src/index.cpp b/src/index.cpp
index b4055da7..4f1877fc 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -1858,8 +1858,9 @@ Option<bool> Index::run_search(search_args* search_params, const std::string& collection_name,
                          search_params->facet_sample_percent,
                          search_params->facet_sample_threshold,
                          collection_name,
-                         facet_index_type,
-                         search_params->drop_tokens_mode);
+                         search_params->drop_tokens_mode,
+                         facet_index_type
+                         );
 }
 
 void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
@@ -2310,8 +2311,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
                            const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold,
                            const std::string& collection_name,
-                           facet_index_type_t facet_index_type,
-                           const drop_tokens_mode_t drop_tokens_mode) const {
+                           const drop_tokens_param_t drop_tokens_mode,
+                           facet_index_type_t facet_index_type
+                           ) const {
 
     std::shared_lock lock(mutex);
 
     auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root);
@@ -2743,10 +2745,22 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
         for (size_t qi = 0; qi < all_queries.size(); qi++) {
             auto& orig_tokens = all_queries[qi];
             size_t num_tokens_dropped = 0;
-            auto curr_direction = drop_tokens_mode;
             size_t total_dirs_done = 0;
 
-            while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
+            // NOTE: when dropping both sides we will ignore exhaustive search
+
+            auto curr_direction = drop_tokens_mode.mode;
+            bool drop_both_sides = false;
+
+            if(drop_tokens_mode.mode == both_sides) {
+                if(orig_tokens.size() <= drop_tokens_mode.token_limit) {
+                    drop_both_sides = true;
+                } else {
+                    curr_direction = right_to_left;
+                }
+            }
+
+            while(exhaustive_search || all_result_ids_len < drop_tokens_threshold || drop_both_sides) {
                 // When at least two tokens from the query are available we can drop one
                 std::vector<token_t> truncated_tokens;
                 std::vector<token_t> dropped_tokens;
@@ -2843,8 +2857,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
             if(has_text_match) {
                 // For hybrid search, we need to give weight to text match and vector search
-                constexpr float TEXT_MATCH_WEIGHT = 0.7;
-                constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT;
+                const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
+                const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
 
                 VectorFilterFunctor filterFunctor(filter_result_iterator);
                 auto& field_vector_index = vector_index.at(vector_query.field_name);
diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp
index a4773161..70a9bc4b 100644
--- a/src/tokenizer.cpp
+++ b/src/tokenizer.cpp
@@ -87,6 +87,13 @@ void Tokenizer::init(const std::string& input) {
     }
 
     unicode_text = icu::UnicodeString::fromUTF8(text);
+
+    if(locale == "fa") {
+        icu::UnicodeString target_str;
+        target_str.setTo(0x200C); // U+200C (ZERO WIDTH NON-JOINER)
+        unicode_text.findAndReplace(target_str, " ");
+    }
+
     bi->setText(unicode_text);
 
     start_pos = bi->first();
diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp
index 67443f2b..dba9d27d 100644
--- a/src/vector_query_ops.cpp
+++ b/src/vector_query_ops.cpp
@@ -156,6 +156,15 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str,
                 vector_query.distance_threshold = std::stof(param_kv[1]);
             }
+
+            if(param_kv[0] == "alpha") {
+                if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 1.0) {
+                    return Option<bool>(400, "Malformed vector query string: "
+                                             "`alpha` parameter must be a float between 0.0-1.0.");
+                }
+
+                vector_query.alpha = std::stof(param_kv[1]);
+            }
         }
 
     return Option<bool>(true);
 }
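With `alpha` wired through, the hybrid weighting becomes a convex combination: `alpha` weights the vector side and `1 - alpha` the text side. A minimal sketch of the resulting fusion arithmetic, assuming the reciprocal-rank contribution that the test assertions further below imply (`rank_fusion_score` here is a hypothetical stand-in, not the actual helper):

```cpp
#include <cstddef>

// Hypothetical sketch: each side contributes weight / rank (1-based ranks),
// with alpha as the vector-search weight and (1 - alpha) the text-match weight.
float rank_fusion_score(size_t text_rank, size_t vector_rank, float alpha) {
    float score = 0.0f;
    if(text_rank != 0) {                      // 0 = no text match for this hit
        score += (1.0f - alpha) / text_rank;
    }
    if(vector_rank != 0) {                    // 0 = no vector match for this hit
        score += alpha / vector_rank;
    }
    return score;    // e.g. alpha = 0.5, vector-only hit at rank 2 -> 0.25
}
```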
diff --git a/test/collection_grouping_test.cpp b/test/collection_grouping_test.cpp
index b98fd02f..8a69c55f 100644
--- a/test/collection_grouping_test.cpp
+++ b/test/collection_grouping_test.cpp
@@ -645,7 +645,7 @@ TEST_F(CollectionGroupingTest, ControlMissingValues) {
                             {}, {}, {"brand"}, 2, "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                             fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                            100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true, false).get();
+                            100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, false).get();
 
     ASSERT_EQ(3, res["grouped_hits"].size());
     ASSERT_EQ("Omega", res["grouped_hits"][0]["group_key"][0].get<std::string>());
@@ -668,7 +668,7 @@ TEST_F(CollectionGroupingTest, ControlMissingValues) {
                        {}, {}, {"brand"}, 2, "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                        fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                       100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true, true).get();
+                       100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, true).get();
 
     ASSERT_EQ(2, res["grouped_hits"].size());
@@ -911,7 +911,7 @@ TEST_F(CollectionGroupingTest, SkipToReverseGroupBy) {
                        {}, {}, {"brand"}, 2, "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                        fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                       100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true, false).get();
+                       100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, false).get();
 
     ASSERT_EQ(1, res["grouped_hits"].size());
@@ -944,7 +944,7 @@ TEST_F(CollectionGroupingTest, SkipToReverseGroupBy) {
                        {}, {}, {"brand"}, 2, "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                        fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                       100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true, false).get();
+                       100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, false).get();
 
     ASSERT_EQ(5, res["grouped_hits"].size());
@@ -973,7 +973,7 @@ TEST_F(CollectionGroupingTest, SkipToReverseGroupBy) {
                        {}, {}, {"brand"}, 2, "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                        fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                       100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true, true).get();
+                       100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, true).get();
 
     ASSERT_EQ(4, res["grouped_hits"].size());
 
diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp
index c89d761a..706a8827 100644
--- a/test/collection_specific_more_test.cpp
+++ b/test/collection_specific_more_test.cpp
@@ -1883,7 +1883,7 @@ TEST_F(CollectionSpecificMoreTest, DisableFieldCountForScoring) {
                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
                                 "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                 fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                                100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true);
+                                100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true);
 
     auto res = coll1->search("beta", {"name", "brand"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5,
                              spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
                              "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                              fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                             100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, false).get();
+                             100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", false).get();
 
     size_t score1 = std::stoul(res["hits"][0]["text_match_info"]["score"].get<std::string>());
     size_t score2 = std::stoul(res["hits"][1]["text_match_info"]["score"].get<std::string>());
@@ -1902,7 +1902,7 @@ TEST_F(CollectionSpecificMoreTest, DisableFieldCountForScoring) {
                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
                                 "", "", {3,3}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                 fallback, 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_score,
-                                100, 0, 0, HASH, 30000, 2, "", {}, {}, right_to_left, true).get();
+                                100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true).get();
 
     ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
     ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
@@ -2413,7 +2413,7 @@ TEST_F(CollectionSpecificMoreTest, DropTokensLeftToRightFirst) {
                            spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
                            "", "", {}, 1000, true, false, true, "", false, 10000,
                            4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
-                           0, HASH, 30000, 2, "", {}, {}, left_to_right).get();
+                           0, HASH, 30000, 2, "", {}, {}, "left_to_right").get();
 
     ASSERT_EQ(1, res["hits"].size());
     ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
@@ -2423,10 +2423,48 @@ TEST_F(CollectionSpecificMoreTest, DropTokensLeftToRightFirst) {
                       spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
                       "", "", {}, 1000, true, false, true, "", false, 10000,
                       4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
-                      0, HASH, 30000, 2, "", {}, {}, right_to_left).get();
+                      0, HASH, 30000, 2, "", {}, {}, "right_to_left").get();
 
     ASSERT_EQ(1, res["hits"].size());
     ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+
+    // search on both sides
+    res = coll1->search("alpha gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
+                        spp::sparse_hash_set<std::string>(),
+                        spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                        "", "", {}, 1000, true, false, true, "", false, 10000,
+                        4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                        0, HASH, 30000, 2, "", {}, {}, "both_sides:3").get();
+
+    ASSERT_EQ(2, res["hits"].size());
+
+    // but must follow token limit
+    res = coll1->search("alpha gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
+                        spp::sparse_hash_set<std::string>(),
+                        spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                        "", "", {}, 1000, true, false, true, "", false, 10000,
+                        4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                        0, HASH, 30000, 2, "", {}, {}, "both_sides:1").get();
+
+    ASSERT_EQ(1, res["hits"].size());
+    ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+
+    // validation checks
+    auto res_op = coll1->search("alpha gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
+                                spp::sparse_hash_set<std::string>(),
+                                spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                                "", "", {}, 1000, true, false, true, "", false, 10000,
+                                4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                                0, HASH, 30000, 2, "", {}, {}, "all_sides");
+    ASSERT_FALSE(res_op.ok());
+    ASSERT_EQ("Invalid format for drop tokens mode.", res_op.error());
+
+    res_op = coll1->search("alpha gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
+                           spp::sparse_hash_set<std::string>(),
+                           spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                           "", "", {}, 1000, true, false, true, "", false, 10000,
+                           4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                           0, HASH, 30000, 2, "", {}, {}, "both_sides:x");
+    ASSERT_FALSE(res_op.ok());
+    ASSERT_EQ("Invalid format for drop tokens mode.", res_op.error());
 }
 
 TEST_F(CollectionSpecificMoreTest, DoNotHighlightFieldsForSpecialCharacterQuery) {
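A hedged reading of the `both_sides` assertions above, for the two-token query `alpha gamma` (the fixture's document contents are not shown here; only the hit counts and one id are asserted):

```cpp
// Mirrors the both_sides gate added to src/index.cpp (illustrative values):
size_t num_tokens = 2;          // "alpha gamma"
size_t token_limit = 3;         // from "both_sides:3"
bool drop_both_sides = num_tokens <= token_limit;   // true: try {"alpha"} and {"gamma"} -> 2 hits
// With "both_sides:1": 2 > 1, so curr_direction degrades to right_to_left -> 1 hit (doc "0").
```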
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index c6bd963d..bfdb2f78 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -2515,3 +2515,154 @@ TEST_F(CollectionVectorTest, TestUnloadModelsCollectionHaveTwoEmbeddingField) {
     text_embedders = TextEmbedderManager::get_instance()._get_text_embedders();
     ASSERT_EQ(0, text_embedders.size());
 }
+
+TEST_F(CollectionVectorTest, TestHybridSearchAlphaParam) {
+    nlohmann::json schema = R"({
+        "name": "test",
+        "fields": [
+            {
+                "name": "name",
+                "type": "string"
+            },
+            {
+                "name": "embedding",
+                "type": "float[]",
+                "embed": {
+                    "from": [
+                        "name"
+                    ],
+                    "model_config": {
+                        "model_name": "ts/e5-small"
+                    }
+                }
+            }
+        ]
+    })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(collection_create_op.ok());
+
+    auto coll = collection_create_op.get();
+
+    auto add_op = coll->add(R"({
+        "name": "soccer"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll->add(R"({
+        "name": "basketball"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll->add(R"({
+        "name": "volleyball"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+
+    // do hybrid search
+    auto hybrid_results = coll->search("sports", {"name", "embedding"},
+                                       "", {}, {}, {2}, 10,
+                                       1, FREQUENCY, {true},
+                                       0, spp::sparse_hash_set<std::string>()).get();
+
+    ASSERT_EQ(3, hybrid_results["hits"].size());
+
+    // check scores
+    ASSERT_FLOAT_EQ(0.3, hybrid_results["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+    ASSERT_FLOAT_EQ(0.15, hybrid_results["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+    ASSERT_FLOAT_EQ(0.10, hybrid_results["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+
+    // do hybrid search with alpha = 0.5
+    hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                  spp::sparse_hash_set<std::string>(),
+                                  spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                  "", 10, {}, {}, {}, 0,
+                                  "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                  fallback,
+                                  4, {off}, 32767, 32767, 2,
+                                  false, true, "embedding:([], alpha:0.5)").get();
+    ASSERT_EQ(3, hybrid_results["hits"].size());
+
+    // check scores
+    ASSERT_FLOAT_EQ(0.5, hybrid_results["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+    ASSERT_FLOAT_EQ(0.25, hybrid_results["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+    ASSERT_FLOAT_EQ(0.16666667, hybrid_results["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+}
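The asserted scores in `TestHybridSearchAlphaParam` follow from vector-only ranking: the query `sports` has no literal text match against the three documents, so each hit scores `alpha / rank` with 1-based ranks. A worked check (arithmetic only, under the fusion assumption sketched earlier):

```cpp
// default alpha = 0.3:  0.3/1 = 0.3,   0.3/2 = 0.15,  0.3/3 = 0.10
// alpha = 0.5:          0.5/1 = 0.5,   0.5/2 = 0.25,  0.5/3 = 0.16666667
```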
+TEST_F(CollectionVectorTest, TestHybridSearchInvalidAlpha) {
+    nlohmann::json schema = R"({
+        "name": "test",
+        "fields": [
+            {
+                "name": "name",
+                "type": "string"
+            },
+            {
+                "name": "embedding",
+                "type": "float[]",
+                "embed": {
+                    "from": [
+                        "name"
+                    ],
+                    "model_config": {
+                        "model_name": "ts/e5-small"
+                    }
+                }
+            }
+        ]
+    })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(collection_create_op.ok());
+
+    auto coll = collection_create_op.get();
+
+
+    // do hybrid search with alpha = 1.5
+    auto hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                       spp::sparse_hash_set<std::string>(),
+                                       spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                       "", 10, {}, {}, {}, 0,
+                                       "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                       fallback,
+                                       4, {off}, 32767, 32767, 2,
+                                       false, true, "embedding:([], alpha:1.5)");
+
+    ASSERT_FALSE(hybrid_results.ok());
+    ASSERT_EQ("Malformed vector query string: "
+              "`alpha` parameter must be a float between 0.0-1.0.", hybrid_results.error());
+
+    // do hybrid search with alpha = -0.5
+    hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                  spp::sparse_hash_set<std::string>(),
+                                  spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                  "", 10, {}, {}, {}, 0,
+                                  "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                  fallback,
+                                  4, {off}, 32767, 32767, 2,
+                                  false, true, "embedding:([], alpha:-0.5)");
+
+    ASSERT_FALSE(hybrid_results.ok());
+    ASSERT_EQ("Malformed vector query string: "
+              "`alpha` parameter must be a float between 0.0-1.0.", hybrid_results.error());
+
+    // do hybrid search with alpha as string
+    hybrid_results = coll->search("sports", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                  spp::sparse_hash_set<std::string>(),
+                                  spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                  "", 10, {}, {}, {}, 0,
+                                  "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                  fallback,
+                                  4, {off}, 32767, 32767, 2,
+                                  false, true, "embedding:([], alpha:\"0.5\")");
+
+    ASSERT_FALSE(hybrid_results.ok());
+    ASSERT_EQ("Malformed vector query string: "
+              "`alpha` parameter must be a float between 0.0-1.0.", hybrid_results.error());
+
+}
diff --git a/test/tokenizer_test.cpp b/test/tokenizer_test.cpp
index b141e97c..054df18b 100644
--- a/test/tokenizer_test.cpp
+++ b/test/tokenizer_test.cpp
@@ -323,6 +323,11 @@ TEST(TokenizerTest, ShouldTokenizeLocaleText) {
     tokens.clear();
     // 配管
     Tokenizer("配管", true, false, "ja").tokenize(tokens);
+
+    // persian containing zwnj
+    tokens.clear();
+    Tokenizer("روان\u200Cشناسی", false, false, "fa").tokenize(tokens);
+    ASSERT_EQ(2, tokens.size());
 }
 
 TEST(TokenizerTest, ShouldTokenizeLocaleTextWithEnglishText) {
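The Persian fix normalizes U+200C (zero-width non-joiner) to a space before ICU's word-break iterator runs, which is why the compound in the test splits into two tokens. A minimal standalone sketch of the same normalization, assuming ICU is available (this mirrors the `tokenizer.cpp` hunk; it is not the actual `Tokenizer` class):

```cpp
#include <unicode/unistr.h>
#include <iostream>
#include <string>

int main() {
    // "روان\u200Cشناسی" ("psychology"): two words joined by a ZWNJ.
    icu::UnicodeString text = icu::UnicodeString::fromUTF8("روان\u200Cشناسی");

    icu::UnicodeString zwnj;
    zwnj.setTo((UChar32) 0x200C);       // U+200C ZERO WIDTH NON-JOINER
    text.findAndReplace(zwnj, " ");     // same replacement as Tokenizer::init()

    std::string out;
    text.toUTF8String(out);
    std::cout << out << "\n";           // the break iterator now sees two words
}
```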