diff --git a/include/string_utils.h b/include/string_utils.h index 728bcdab..f497718c 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -319,4 +319,6 @@ struct StringUtils { static std::string trim_curly_spaces(const std::string& str); static bool ends_with(std::string const &str, std::string const &ending); + + static bool contains_word(const std::string& haystack, const std::string& needle); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index e91af441..9357d48c 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -431,8 +431,9 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo continue; } - if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) || - (override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) { + if ((override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) || + (override.rule.match == override_t::MATCH_CONTAINS && + StringUtils::contains_word(query, override.rule.query))) { // have to ensure that dropped hits take precedence over added hits for(const auto & hit: override.drop_hits) { diff --git a/src/index.cpp b/src/index.cpp index 10a11fa5..a41500b8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1153,12 +1153,12 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold; //LOG(INFO) << "n: " << n; - /*std::stringstream fullq; + std::stringstream fullq; for(const auto& qleaf : actual_query_suggestion) { std::string qtok(reinterpret_cast(qleaf->key),qleaf->key_len - 1); fullq << qtok << " "; } - LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;*/ + LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost; // Prepare excluded document IDs that we can later remove from the result set uint32_t* excluded_result_ids = nullptr; @@ -1735,8 +1735,9 @@ bool Index::static_filter_query_eval(const override_t* override, std::string query = StringUtils::join(tokens, " "); - if( (override->rule.match == override_t::MATCH_EXACT && override->rule.query == query) || - (override->rule.match == override_t::MATCH_CONTAINS && query.find(override->rule.query) != std::string::npos) ) { + if ((override->rule.match == override_t::MATCH_EXACT && override->rule.query == query) || + (override->rule.match == override_t::MATCH_CONTAINS && + StringUtils::contains_word(query, override->rule.query))) { Option filter_op = filter::parse_filter_query(override->filter_by, search_schema, store, "", filters); diff --git a/src/string_utils.cpp b/src/string_utils.cpp index f542273c..221bc2eb 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -291,6 +291,28 @@ bool StringUtils::ends_with(const std::string& str, const std::string& ending) { } } +bool StringUtils::contains_word(const std::string& haystack, const std::string& needle) { + size_t pos = haystack.find(needle); + if(pos == std::string::npos) { + return false; + } + + if(pos == 0 && haystack.size() == needle.size()) { + return true; + } + + if(pos != 0 && haystack[pos - 1] != ' ') { + return false; + } + + size_t end_pos = pos + needle.size(); + if(end_pos < haystack.size() and haystack[end_pos] != ' ') { + return false; + } + + return true; +} + /*size_t StringUtils::unicode_length(const std::string& bytes) { std::wstring_convert, char32_t> utf8conv; return utf8conv.from_bytes(bytes).size(); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 86193ea8..f5ee891b 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -171,6 +171,14 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeExactQueryMatch) { ASSERT_STREQ("2", results["hits"][2]["document"]["id"].get().c_str()); ASSERT_STREQ("1", results["hits"][3]["document"]["id"].get().c_str()); + // partial word should not match + res_op = coll_mul_fields->search("dowillow", {"title"}, "", {}, {}, {0}, 10); + ASSERT_TRUE(res_op.ok()); + results = res_op.get(); + + ASSERT_EQ(0, results["hits"].size()); + ASSERT_EQ(0, results["found"].get()); + // ability to disable overrides bool enable_overrides = false; res_op = coll_mul_fields->search("will", {"title"}, "", {}, {}, {0}, 10, @@ -1531,6 +1539,13 @@ TEST_F(CollectionOverrideTest, StaticFiltering) { ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + // partial word should not match + results = coll1->search("inexpensive shoes", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + // with synonum for expensive synonym_t synonym1{"costly-expensive", {"costly"}, {{"expensive"}} }; coll1->add_synonym(synonym1); diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index a6c33239..5d800afe 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -282,3 +282,17 @@ TEST(StringUtilsTest, ShouldTrimCurlySpaces) { ASSERT_EQ("{}", StringUtils::trim_curly_spaces("{ }")); ASSERT_EQ("foo {bar} {baz}", StringUtils::trim_curly_spaces("foo { bar } { baz}")); } + +TEST(StringUtilsTest, ContainsWord) { + ASSERT_TRUE(StringUtils::contains_word("foo bar", "foo")); + ASSERT_TRUE(StringUtils::contains_word("foo bar", "bar")); + ASSERT_TRUE(StringUtils::contains_word("foo bar baz", "bar")); + ASSERT_TRUE(StringUtils::contains_word("foo bar baz", "foo bar")); + ASSERT_TRUE(StringUtils::contains_word("foo bar baz", "bar baz")); + + ASSERT_FALSE(StringUtils::contains_word("foobar", "bar")); + ASSERT_FALSE(StringUtils::contains_word("foobar", "foo")); + ASSERT_FALSE(StringUtils::contains_word("foobar baz", "bar")); + ASSERT_FALSE(StringUtils::contains_word("foobar baz", "bar baz")); + ASSERT_FALSE(StringUtils::contains_word("baz foobar", "foo")); +}