diff --git a/src/index.cpp b/src/index.cpp index 912069ef..03a928fe 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2045,23 +2045,46 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string& &result_ids, result_ids_len, field_num_results, 0, group_by_fields, false, 4, query_hashes, token_order, false, 0, 1, false, -1, 3, 7, 4); - delete [] result_ids; - if(result_ids_len != 0) { - // remove window_tokens from `tokens` - std::vector new_tokens; - for(size_t new_i = start_index; new_i < tokens.size(); new_i++) { - const auto& token = tokens[new_i]; - if(window_tokens_set.count(token) == 0) { - new_tokens.emplace_back(token); - } else { - absorbed_tokens.insert(token); - field_absorbed_tokens.emplace_back(token); + // we need to narrow down to the exact matches + std::vector posting_lists; + art_tree* t = search_index.at(field_name); + + for(auto& w_token: window_tokens) { + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) w_token.value.c_str(), + w_token.value.length()+1); + if(leaf == nullptr) { + continue; } + + posting_lists.push_back(leaf->values); } - tokens = new_tokens; - return true; + uint32_t* exact_strt_ids = new uint32_t[result_ids_len]; + size_t exact_strt_size = 0; + + posting_t::get_exact_matches(posting_lists, field_it->second.is_array(), result_ids, result_ids_len, + exact_strt_ids, exact_strt_size); + + delete [] result_ids; + delete [] exact_strt_ids; + + if(exact_strt_size != 0) { + // remove window_tokens from `tokens` + std::vector new_tokens; + for(size_t new_i = start_index; new_i < tokens.size(); new_i++) { + const auto& token = tokens[new_i]; + if(window_tokens_set.count(token) == 0) { + new_tokens.emplace_back(token); + } else { + absorbed_tokens.insert(token); + field_absorbed_tokens.emplace_back(token); + } + } + + tokens = new_tokens; + return true; + } } if(!slide_window) { diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index faad8572..9caecaaa 100644 --- 
a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -2061,3 +2061,55 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithJustRemoveTokens) { collectionManager.drop_collection("coll1"); } + +TEST_F(CollectionOverrideTest, DynamicFilteringWithPartialTokenMatch) { + // When the query tokens do not exactly match the placeholder field's value, the dynamic filter must not be applied. + Collection* coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("category", field_types::STRING, true),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if (coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields).get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["category"] = "Running Shoes"; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + + std::vector sort_fields = {sort_by("_text_match", "DESC")}; + + auto results = coll1->search("shoes", {"name"}, "", + {}, sort_fields, {0}, 10).get(); + + ASSERT_EQ(1, results["hits"].size()); + + // Even with the override registered, the same single hit is expected: "shoes" is only part of the `category` value ("Running Shoes"), so the filter must not be applied. + + nlohmann::json override_json = { + {"id", "dynamic-filter"}, + { + "rule", { + {"query", "{ category }"}, + {"match", override_t::MATCH_EXACT} + } + }, + {"filter_by", "category:= {category}"}, + {"remove_matched_tokens", true} + }; + + override_t override; + auto op = override_t::parse(override_json, "dynamic-filter", override); + ASSERT_TRUE(op.ok()); + coll1->add_override(override); + + results = coll1->search("shoes", {"name"}, "", + {}, sort_fields, {0}, 10).get(); + + ASSERT_EQ(1, results["hits"].size()); + collectionManager.drop_collection("coll1"); +}