From 15114a6c87566d96b71c86351b0909e897edbde5 Mon Sep 17 00:00:00 2001 From: Krunal Gandhi Date: Fri, 5 Apr 2024 08:55:28 +0000 Subject: [PATCH] add boolean for enabling typos for alphanumeric tokens (#1651) Co-authored-by: Kishore Nallan --- include/collection.h | 6 +- include/index.h | 21 ++++--- src/collection.cpp | 17 ++++-- src/collection_manager.cpp | 6 +- src/index.cpp | 43 ++++++++++---- test/collection_specific_more_test.cpp | 81 ++++++++++++++++++++++++++ 6 files changed, 146 insertions(+), 28 deletions(-) diff --git a/include/collection.h b/include/collection.h index 32c5e50f..458104f1 100644 --- a/include/collection.h +++ b/include/collection.h @@ -278,7 +278,8 @@ private: std::vector>& included_ids, std::vector& excluded_ids, nlohmann::json& override_metadata, - bool enable_typos_for_numerical_tokens=true) const; + bool enable_typos_for_numerical_tokens=true, + bool enable_typos_for_alpha_numerical_tokens=true) const; void populate_text_match_info(nlohmann::json& info, uint64_t match_score, const text_match_type_t match_type, const size_t total_tokens) const; @@ -588,7 +589,8 @@ public: bool enable_synonyms = true, bool synonym_prefix = false, uint32_t synonym_num_typos = 0, - bool enable_lazy_filter = false) const; + bool enable_lazy_filter = false, + bool enable_typos_for_alpha_numerical_tokens = true) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/index.h b/include/index.h index 17793bfc..36985e4c 100644 --- a/include/index.h +++ b/include/index.h @@ -468,13 +468,15 @@ private: const std::vector& query_tokens, token_ordering token_order, std::set& absorbed_tokens, std::string& filter_by_clause, - bool enable_typos_for_numerical_tokens) const; + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const; bool check_for_overrides(const token_ordering& token_order, const string& field_name, bool slide_window, bool exact_rule_match, std::vector& tokens, std::set& absorbed_tokens, std::vector& field_absorbed_tokens, - bool enable_typos_for_numerical_tokens) const; + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const; static void aggregate_topster(Topster* agg_topster, Topster* index_topster); @@ -638,7 +640,8 @@ public: facet_index_t* _get_facet_index() const; static int get_bounded_typo_cost(const size_t max_cost, const std::string& token, const size_t token_len, - size_t min_len_1typo, size_t min_len_2typo, bool enable_typos_for_numerical_tokens=true); + size_t min_len_1typo, size_t min_len_2typo, bool enable_typos_for_numerical_tokens=true, + bool enable_typos_for_alpha_numerical_tokens = true); static int64_t float_to_int64_t(float n); @@ -663,7 +666,8 @@ public: Option run_search(search_args* search_params, const std::string& collection_name, facet_index_type_t facet_index_type, bool enable_typos_for_numerical_tokens, - bool enable_synonyms, bool synonym_prefix, uint32_t synonym_num_typos); + bool enable_synonyms, bool synonym_prefix, uint32_t synonym_num_typos, + bool enable_typos_for_alpha_numerical_tokens); Option search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, @@ -697,7 +701,8 @@ public: bool enable_synonyms = true, bool synonym_prefix = false, uint32_t synonym_num_typos = 0, - bool enable_lazy_filter = false + bool enable_lazy_filter = false, + bool enable_typos_for_alpha_numerical_tokens = true ) const; void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name, @@ -920,7 +925,8 @@ public: std::array*, 3>& field_values, const std::vector& geopoint_indices, const std::string& collection_name = "", - bool enable_typos_for_numerical_tokens = true) const; + bool enable_typos_for_numerical_tokens = true, + bool enable_typos_for_alpha_numerical_tokens = true) const; void find_across_fields(const token_t& previous_token, const std::string& previous_token_str, @@ -991,7 +997,8 @@ public: filter_node_t*& filter_tree_root, std::vector& matched_dynamic_overrides, nlohmann::json& override_metadata, - bool enable_typos_for_numerical_tokens) const; + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const; Option compute_sort_scores(const std::vector& sort_fields, const int* sort_order, std::array*, 3> field_values, diff --git a/src/collection.cpp b/src/collection.cpp index 2c41393b..e6e9fe66 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1763,7 +1763,8 @@ Option Collection::search(std::string raw_query, bool enable_synonyms, bool synonym_prefix, uint32_t synonyms_num_typos, - bool enable_lazy_filter) const { + bool enable_lazy_filter, + bool enable_typos_for_alpha_numerical_tokens) const { std::shared_lock lock(mutex); // setup thread local vars @@ -2293,7 +2294,8 @@ Option Collection::search(std::string raw_query, false, stopwords_set); process_filter_overrides(filter_overrides, q_include_tokens, token_order, filter_tree_root, - included_ids, excluded_ids, override_metadata, enable_typos_for_numerical_tokens); + included_ids, excluded_ids, override_metadata, enable_typos_for_numerical_tokens, + enable_typos_for_alpha_numerical_tokens); for(size_t i = 0; i < q_include_tokens.size(); i++) { auto& q_include_token = q_include_tokens[i]; @@ -2314,7 +2316,8 @@ Option Collection::search(std::string raw_query, // included_ids, excluded_ids process_filter_overrides(filter_overrides, q_include_tokens, token_order, filter_tree_root, - included_ids, excluded_ids, override_metadata, enable_typos_for_numerical_tokens); + included_ids, excluded_ids, override_metadata, enable_typos_for_numerical_tokens, + enable_typos_for_alpha_numerical_tokens); for(size_t i = 0; i < q_include_tokens.size(); i++) { auto& q_include_token = q_include_tokens[i]; @@ -2360,7 +2363,7 @@ Option Collection::search(std::string raw_query, auto search_op = index->run_search(search_params, name, facet_index_type, enable_typos_for_numerical_tokens, enable_synonyms, synonym_prefix, - synonyms_num_typos); + synonyms_num_typos, enable_typos_for_alpha_numerical_tokens); // filter_tree_root might be updated in Index::static_filter_query_eval. filter_tree_root_guard.release(); @@ -3395,12 +3398,14 @@ void Collection::process_filter_overrides(std::vector& filter std::vector>& included_ids, std::vector& excluded_ids, nlohmann::json& override_metadata, - bool enable_typos_for_numerical_tokens) const { + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const { std::vector matched_dynamic_overrides; index->process_filter_overrides(filter_overrides, q_include_tokens, token_order, filter_tree_root, matched_dynamic_overrides, override_metadata, - enable_typos_for_numerical_tokens); + enable_typos_for_numerical_tokens, + enable_typos_for_alpha_numerical_tokens); // we will check the dynamic overrides to see if they also have include/exclude std::set excluded_set; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index bd330d0c..c5b37c74 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1479,6 +1479,7 @@ Option CollectionManager::do_search(std::map& re const char *VOICE_QUERY = "voice_query"; const char *ENABLE_TYPOS_FOR_NUMERICAL_TOKENS = "enable_typos_for_numerical_tokens"; + const char *ENABLE_TYPOS_FOR_ALPHA_NUMERICAL_TOKENS = "enable_typos_for_alpha_numerical_tokens"; const char *ENABLE_LAZY_FILTER = "enable_lazy_filter"; const char *SYNONYM_PREFIX = "synonym_prefix"; @@ -1607,6 +1608,7 @@ Option CollectionManager::do_search(std::map& re bool enable_highlight_v1 = true; text_match_type_t match_type = max_score; bool enable_typos_for_numerical_tokens = true; + bool enable_typos_for_alpha_numerical_tokens = true; bool enable_lazy_filter = Config::get_instance().get_enable_lazy_filter(); size_t remote_embedding_timeout_ms = 5000; @@ -1684,6 +1686,7 @@ Option CollectionManager::do_search(std::map& re {ENABLE_SYNONYMS, &enable_synonyms}, {SYNONYM_PREFIX, &synonym_prefix}, {ENABLE_LAZY_FILTER, &enable_lazy_filter}, + {ENABLE_TYPOS_FOR_ALPHA_NUMERICAL_TOKENS, &enable_typos_for_alpha_numerical_tokens}, }; std::unordered_map*> str_list_values = { @@ -1902,7 +1905,8 @@ Option CollectionManager::do_search(std::map& re enable_synonyms, synonym_prefix, synonym_num_typos, - enable_lazy_filter); + enable_lazy_filter, + enable_typos_for_alpha_numerical_tokens); uint64_t timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); diff --git a/src/index.cpp b/src/index.cpp index d7a5ead6..73319d21 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2212,7 +2212,8 @@ Option Index::do_filtering_with_reference_ids(const std::string Option Index::run_search(search_args* search_params, const std::string& collection_name, facet_index_type_t facet_index_type, bool enable_typos_for_numerical_tokens, - bool enable_synonyms, bool synonym_prefix, uint32_t synonym_num_typos) { + bool enable_synonyms, bool synonym_prefix, uint32_t synonym_num_typos, + bool enable_typos_for_alpha_numerical_tokens) { return search(search_params->field_query_tokens, search_params->search_fields, search_params->match_type, @@ -2257,7 +2258,8 @@ Option Index::run_search(search_args* search_params, const std::string& co enable_synonyms, synonym_prefix, synonym_num_typos, - search_params->enable_lazy_filter + search_params->enable_lazy_filter, + enable_typos_for_alpha_numerical_tokens ); } @@ -2346,7 +2348,8 @@ bool Index::static_filter_query_eval(const override_t* override, bool Index::resolve_override(const std::vector& rule_tokens, const bool exact_rule_match, const std::vector& query_tokens, token_ordering token_order, std::set& absorbed_tokens, - std::string& filter_by_clause, bool enable_typos_for_numerical_tokens) const { + std::string& filter_by_clause, bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const { bool resolved_override = false; size_t i = 0, j = 0; @@ -2389,7 +2392,8 @@ bool Index::resolve_override(const std::vector& rule_tokens, const std::vector field_absorbed_tokens; resolved_override &= check_for_overrides(token_order, field_name, slide_window, exact_rule_match, matched_tokens, absorbed_tokens, - field_absorbed_tokens, enable_typos_for_numerical_tokens); + field_absorbed_tokens, enable_typos_for_numerical_tokens, + enable_typos_for_alpha_numerical_tokens); if(!resolved_override) { goto RETURN_EARLY; @@ -2441,7 +2445,8 @@ void Index::process_filter_overrides(const std::vector& filte filter_node_t*& filter_tree_root, std::vector& matched_dynamic_overrides, nlohmann::json& override_metadata, - bool enable_typos_for_numerical_tokens) const { + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const { std::shared_lock lock(mutex); for (auto& override : filter_overrides) { @@ -2478,7 +2483,8 @@ void Index::process_filter_overrides(const std::vector& filte std::set absorbed_tokens; bool resolved_override = resolve_override(rule_parts, exact_rule_match, query_tokens, token_order, absorbed_tokens, filter_by_clause, - enable_typos_for_numerical_tokens); + enable_typos_for_numerical_tokens, + enable_typos_for_alpha_numerical_tokens); if (resolved_override) { if(override_metadata.empty()) { @@ -2536,7 +2542,8 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string& bool exact_rule_match, std::vector& tokens, std::set& absorbed_tokens, std::vector& field_absorbed_tokens, - bool enable_typos_for_numerical_tokens) const { + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const { for(size_t window_len = tokens.size(); window_len > 0; window_len--) { for(size_t start_index = 0; start_index+window_len-1 < tokens.size(); start_index++) { @@ -2748,7 +2755,8 @@ Option Index::search(std::vector& field_query_tokens, cons bool enable_typos_for_numerical_tokens, bool enable_synonyms, bool synonym_prefix, uint32_t synonym_num_typos, - bool enable_lazy_filter) const { + bool enable_lazy_filter, + bool enable_typos_for_alpha_numerical_tokens) const { std::shared_lock lock(mutex); auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root, @@ -3145,7 +3153,8 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices, - collection_name, enable_typos_for_numerical_tokens); + collection_name, enable_typos_for_numerical_tokens, + enable_typos_for_alpha_numerical_tokens); if (!fuzzy_search_fields_op.ok()) { return fuzzy_search_fields_op; } @@ -3924,7 +3933,8 @@ Option Index::fuzzy_search_fields(const std::vector& the_f std::array*, 3>& field_values, const std::vector& geopoint_indices, const std::string& collection_name, - bool enable_typos_for_numerical_tokens) const { + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) const { // NOTE: `query_tokens` preserve original tokens, while `search_tokens` could be a result of dropped tokens @@ -3939,7 +3949,7 @@ Option Index::fuzzy_search_fields(const std::vector& the_f std::vector all_costs; // This ensures that we don't end up doing a cost of 1 for a single char etc. int bounded_cost = get_bounded_typo_cost(2, token , token.length(), min_len_1typo, min_len_2typo, - enable_typos_for_numerical_tokens); + enable_typos_for_numerical_tokens, enable_typos_for_alpha_numerical_tokens); for(int cost = 0; cost <= bounded_cost; cost++) { all_costs.push_back(cost); @@ -6166,7 +6176,16 @@ void Index::populate_sort_mapping_with_lock(int* sort_order, std::vector int Index::get_bounded_typo_cost(const size_t max_cost, const std::string& token, const size_t token_len, const size_t min_len_1typo, const size_t min_len_2typo, - bool enable_typos_for_numerical_tokens) { + bool enable_typos_for_numerical_tokens, + bool enable_typos_for_alpha_numerical_tokens) { + + if(!enable_typos_for_alpha_numerical_tokens) { + for(auto c : token) { + if(!isalnum(c)) { //some special char which is indexed + return 0; + } + } + } if(!enable_typos_for_numerical_tokens && std::all_of(token.begin(), token.end(), ::isdigit)) { return 0; diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 3f696605..d42024d8 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -3012,4 +3012,85 @@ TEST_F(CollectionSpecificMoreTest, TestFieldStore) { ASSERT_EQ(1, res.get()["hits"].size()); ASSERT_EQ("store", res.get()["hits"][0]["document"]["word_to_store"].get()); ASSERT_TRUE(res.get()["hits"][0]["document"].count("word_not_to_store") == 0); +} + +TEST_F(CollectionSpecificMoreTest, EnableTyposForAlphaNumericalTokens) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string"} + ], + "symbols_to_index":["/"] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + doc["title"] = "c-136/14"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["title"] = "13/14"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["title"] = "(136)214"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["title"] = "c136/14"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["title"] = "A-136/14"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + bool enable_typos_for_alpha_numerical_tokens = false; + + auto res = coll1->search("c-136/14", {"title"}, "", {}, + {}, {2}, 10, 1,FREQUENCY, {true}, + Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", + 30, 4, "", 40, + {}, {}, {}, 0,"", + "", {}, 1000,true, + false, true, "", false, + 6000*1000, 4, 7, fallback, 4, + {off}, INT16_MAX, INT16_MAX,2, + 2, false, "", true, + 0, max_score, 100, 0, 0, + HASH, 30000, 2, "", + {},{}, "right_to_left", true, + true, false, "", "", "", + "", true, true, false, 0, true, + enable_typos_for_alpha_numerical_tokens).get(); + + ASSERT_EQ(2, res["hits"].size()); + + ASSERT_EQ("c136/14", res["hits"][0]["document"]["title"].get()); + ASSERT_EQ("c-136/14", res["hits"][1]["document"]["title"].get()); + + enable_typos_for_alpha_numerical_tokens = true; + + res = coll1->search("c-136/14", {"title"}, "", {}, + {}, {2}, 10, 1,FREQUENCY, {true}, + Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", + 30, 4, "", 40, + {}, {}, {}, 0,"", + "", {}, 1000,true, + false, true, "", false, + 6000*1000, 4, 7, fallback, 4, + {off}, INT16_MAX, INT16_MAX,2, + 2, false, "", true, + 0, max_score, 100, 0, 0, + HASH, 30000, 2, "", + {},{}, "right_to_left", true, + true, false, "", "", "", + "", true, true, false, 0, true, + enable_typos_for_alpha_numerical_tokens).get(); + + ASSERT_EQ(5, res["hits"].size()); + + ASSERT_EQ("c136/14", res["hits"][0]["document"]["title"].get()); + ASSERT_EQ("c-136/14", res["hits"][1]["document"]["title"].get()); + ASSERT_EQ("A-136/14", res["hits"][2]["document"]["title"].get()); + ASSERT_EQ("(136)214", res["hits"][3]["document"]["title"].get()); + ASSERT_EQ("13/14", res["hits"][4]["document"]["title"].get()); } \ No newline at end of file