diff --git a/include/override.h b/include/override.h
index 344f32ae..a92c01c2 100644
--- a/include/override.h
+++ b/include/override.h
@@ -9,6 +9,7 @@ struct override_t {
 
     struct rule_t {
         std::string query;
+        std::string normalized_query;   // not actually stored: used for case-insensitive matching
         std::string match;
         bool dynamic_query = false;
         std::string filter_by;
diff --git a/include/tokenizer.h b/include/tokenizer.h
index 3700a7d3..a6a88ab6 100644
--- a/include/tokenizer.h
+++ b/include/tokenizer.h
@@ -87,4 +87,6 @@ public:
     void decr_token_counter();
 
     bool should_skip_char(char c);
+
+    static void normalize_ascii(std::string& text);
 };
\ No newline at end of file
diff --git a/src/collection.cpp b/src/collection.cpp
index d6aa3403..e4737326 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -652,9 +652,9 @@ void Collection::curate_results(string& actual_query, const string& filter_query
         bool filter_by_match = (override.rule.query.empty() && override.rule.match.empty() &&
                                 !override.rule.filter_by.empty() && override.rule.filter_by == filter_query);
 
-        bool query_match = (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) ||
+        bool query_match = (override.rule.match == override_t::MATCH_EXACT && override.rule.normalized_query == query) ||
                            (override.rule.match == override_t::MATCH_CONTAINS &&
-                            StringUtils::contains_word(query, override.rule.query));
+                            StringUtils::contains_word(query, override.rule.normalized_query));
 
         if (filter_by_match || query_match) {
             if(!override.rule.filter_by.empty() && override.rule.filter_by != filter_query) {
@@ -686,7 +686,7 @@ void Collection::curate_results(string& actual_query, const string& filter_query
                 actual_query = override.replace_query;
             } else if(override.remove_matched_tokens && override.filter_by.empty()) {
                 // don't prematurely remove tokens from query because dynamic filtering will require them
-                StringUtils::replace_all(query, override.rule.query, "");
+                StringUtils::replace_all(query, override.rule.normalized_query, "");
                 StringUtils::trim(query);
                 if(query.empty()) {
                     query = "*";
diff --git a/src/index.cpp b/src/index.cpp
index 6e99d0d0..21a80cad 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -2308,9 +2308,9 @@ bool Index::static_filter_query_eval(const override_t* override,
                                      filter_node_t*& filter_tree_root) const {
    std::string query = StringUtils::join(tokens, " ");

-    if ((override->rule.match == override_t::MATCH_EXACT && override->rule.query == query) ||
+    if ((override->rule.match == override_t::MATCH_EXACT && override->rule.normalized_query == query) ||
         (override->rule.match == override_t::MATCH_CONTAINS &&
-         StringUtils::contains_word(query, override->rule.query))) {
+         StringUtils::contains_word(query, override->rule.normalized_query))) {
         filter_node_t* new_filter_tree_root = nullptr;
         Option<bool> filter_op = filter::parse_filter_query(override->filter_by, search_schema,
                                                             store, "", new_filter_tree_root);
@@ -2453,7 +2453,7 @@ void Index::process_filter_overrides(const std::vector<const override_t*>& filte
             // we will cover both original query and synonyms
 
             std::vector<std::string> rule_parts;
-            StringUtils::split(override->rule.query, rule_parts, " ");
+            StringUtils::split(override->rule.normalized_query, rule_parts, " ");
 
             uint32_t* field_override_ids = nullptr;
             size_t field_override_ids_len = 0;
diff --git a/src/override.cpp b/src/override.cpp
index efe0002b..96f96d38 100644
--- a/src/override.cpp
+++ b/src/override.cpp
@@ -1,5 +1,6 @@
 #include <string>
 #include "override.h"
+#include "tokenizer.h"
 
 Option<bool> override_t::parse(const nlohmann::json& override_json, const std::string& id, override_t& override) {
     if(!override_json.is_object()) {
@@ -108,6 +109,11 @@ Option<bool> override_t::parse(const nlohmann::json& override_json, const std::s
     override.rule.query = json_rule.count("query") == 0 ? "" : json_rule["query"].get<std::string>();
     override.rule.match = json_rule.count("match") == 0 ? "" : json_rule["match"].get<std::string>();
 
+    if(!override.rule.query.empty()) {
+        override.rule.normalized_query = override.rule.query;
+        Tokenizer::normalize_ascii(override.rule.normalized_query);
+    }
+
     if(json_rule.count("filter_by") != 0) {
         if(!override_json["rule"]["filter_by"].is_string()) {
             return Option<bool>(400, "Override `rule.filter_by` must be a string.");
@@ -172,15 +178,15 @@ Option<bool> override_t::parse(const nlohmann::json& override_json, const std::s
 
     // we have to also detect if it is a dynamic query rule
     size_t i = 0;
-    while(i < override.rule.query.size()) {
-        if(override.rule.query[i] == '{') {
+    while(i < override.rule.normalized_query.size()) {
+        if(override.rule.normalized_query[i] == '{') {
             // look for closing curly
             i++;
-            while(i < override.rule.query.size()) {
-                if(override.rule.query[i] == '}') {
+            while(i < override.rule.normalized_query.size()) {
+                if(override.rule.normalized_query[i] == '}') {
                     override.rule.dynamic_query = true;
                     // remove spaces around curlies
-                    override.rule.query = StringUtils::trim_curly_spaces(override.rule.query);
+                    override.rule.normalized_query = StringUtils::trim_curly_spaces(override.rule.normalized_query);
                     break;
                 }
                 i++;
diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp
index 8488b24a..45192787 100644
--- a/src/tokenizer.cpp
+++ b/src/tokenizer.cpp
@@ -338,3 +338,11 @@ void Tokenizer::decr_token_counter() {
 bool Tokenizer::should_skip_char(char c) {
     return is_ascii_char(c) && get_stream_mode(c) != INDEX;
 }
+
+void Tokenizer::normalize_ascii(std::string& text) {
+    for(size_t i = 0; i < text.size(); i++) {
+        if(is_ascii_char(text[i])) {
+            text[i] = std::tolower(text[i]);
+        }
+    }
+}
diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp
index 48ea30c5..f3ca16a1 100644
--- a/test/collection_override_test.cpp
+++ b/test/collection_override_test.cpp
@@ -875,6 +875,79 @@ TEST_F(CollectionOverrideTest, ReplaceQuery) {
     ASSERT_TRUE(op.ok());
 }
 
+TEST_F(CollectionOverrideTest, RuleQueryMustBeCaseInsensitive) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("name", field_types::STRING, false),
+                                 field("points", field_types::INT32, false)};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["name"] = "Amazing Shoes";
+    doc1["points"] = 30;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["name"] = "Tennis Ball";
+    doc2["points"] = 50;
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["name"] = "Golf Ball";
+    doc3["points"] = 1;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+
+    std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
+
+    nlohmann::json override_json = R"({
+        "id": "rule-1",
+        "rule": {
+            "query": "GrEat",
+            "match": "contains"
+        },
+        "replace_query": "amazing"
+    })"_json;
+
+    override_t override_rule;
+    auto op = override_t::parse(override_json, "rule-1", override_rule);
+    ASSERT_TRUE(op.ok());
+    coll1->add_override(override_rule);
+
+    override_json = R"({
+        "id": "rule-2",
+        "rule": {
+            "query": "BaLL",
+            "match": "contains"
+        },
+        "filter_by": "points: 1"
+    })"_json;
+
+    override_t override_rule2;
+    op = override_t::parse(override_json, "rule-2", override_rule2);
+    ASSERT_TRUE(op.ok());
+    coll1->add_override(override_rule2);
+
+    auto results = coll1->search("great shoes", {"name"}, "",
+                                 {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+    ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
+
+    results = coll1->search("ball", {"name"}, "",
+                            {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+    ASSERT_EQ("2", results["hits"][0]["document"]["id"].get<std::string>());
+}
+
 TEST_F(CollectionOverrideTest, WindowForRule) {
     Collection *coll1;