From 6bc60adbaef559f2c18711dfe0b2a465f8bdecd3 Mon Sep 17 00:00:00 2001
From: Kishore Nallan
Date: Wed, 13 Dec 2023 12:12:33 +0530
Subject: [PATCH] Support overriding wildcard query.

---
Notes (reviewer, ignored by git-am):
 * src/collection.cpp: Collection::search now calls process_filter_overrides()
   immediately after parse_search_query().
 * test/collection_override_test.cpp: adds the WildcardSearchOverride test and
   gives the second override in WildcardTagRuleThatMatchesAllQueries its own
   override_json2 / override2 variables.
 * test/collection_all_fields_test.cpp: fixes the "typensense" typo in the
   model directory path.
 * Template arguments (<field>, <sort_by>, <std::string>, <nlohmann::json>)
   were restored after being stripped from this copy of the patch; context-line
   whitespace was reconstructed and should be confirmed against the tree.

 src/collection.cpp                  |   4 ++
 test/collection_all_fields_test.cpp |   2 +-
 test/collection_override_test.cpp   | 102 +++++++++++++++++++++++++++-
 3 files changed, 104 insertions(+), 4 deletions(-)

diff --git a/src/collection.cpp b/src/collection.cpp
index 8d3fd554..b523ecb0 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -2155,6 +2155,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
             parse_search_query(query, q_include_tokens,
                                field_query_tokens[0].q_exclude_tokens, field_query_tokens[0].q_phrases, "",
                                false, stopwords_set);
+
+            process_filter_overrides(filter_overrides, q_include_tokens, token_order, filter_tree_root,
+                                     included_ids, excluded_ids, override_metadata);
+
             for(size_t i = 0; i < q_include_tokens.size(); i++) {
                 auto& q_include_token = q_include_tokens[i];
                 field_query_tokens[0].q_include_tokens.emplace_back(i, q_include_token, (i == q_include_tokens.size() - 1),
diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp
index ae9f0d98..fab4adb3 100644
--- a/test/collection_all_fields_test.cpp
+++ b/test/collection_all_fields_test.cpp
@@ -1591,7 +1591,7 @@ TEST_F(CollectionAllFieldsTest, FieldNameMatchingRegexpShouldNotBeIndexedInNonAu
 }
 
 TEST_F(CollectionAllFieldsTest, EmbedFromFieldJSONInvalidField) {
-    EmbedderManager::set_model_dir("/tmp/typensense_test/models");
+    EmbedderManager::set_model_dir("/tmp/typesense_test/models");
     nlohmann::json field_json;
     field_json["name"] = "embedding";
     field_json["type"] = "float[]";
diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp
index 8a97bb18..6f78041b 100644
--- a/test/collection_override_test.cpp
+++ b/test/collection_override_test.cpp
@@ -3728,7 +3728,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) {
     // includes instead of filter_by
     coll1->remove_override("ov-1");
 
-    override_json1 = R"({
+    auto override_json2 = R"({
         "id": "ov-1",
         "rule": {
             "tags": ["*"]
@@ -3738,9 +3738,10 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) {
         ]
     })"_json;
 
-    op = override_t::parse(override_json1, "ov-1", override1);
+    override_t override2;
+    op = override_t::parse(override_json2, "ov-2", override2);
     ASSERT_TRUE(op.ok());
-    coll1->add_override(override1);
+    coll1->add_override(override2);
 
     results = coll1->search("foobar", {"name"}, "",
                             {}, sort_fields, {2}, 10, 1, FREQUENCY,
@@ -3907,3 +3908,98 @@ TEST_F(CollectionOverrideTest, MetadataValidation) {
 
     collectionManager.drop_collection("coll1");
 }
+
+TEST_F(CollectionOverrideTest, WildcardSearchOverride) {
+    Collection* coll1;
+
+    std::vector<field> fields = {field("name", field_types::STRING, false),
+                                 field("category", field_types::STRING, true),};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if (coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "").get();
+    }
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["name"] = "queryA";
+    doc1["category"] = "kids";
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["name"] = "queryA";
+    doc2["category"] = "kitchen";
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["name"] = "Clay Toy";
+    doc3["category"] = "home";
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+
+    std::vector<sort_by> sort_fields = {sort_by("_text_match", "DESC")};
+
+    nlohmann::json override_json1 = R"({
+        "id": "ov-1",
+        "rule": {
+            "query": "*",
+            "match": "exact"
+        },
+        "filter_by": "category: kids"
+    })"_json;
+
+    override_t override1;
+    auto op = override_t::parse(override_json1, "ov-1", override1);
+    ASSERT_TRUE(op.ok());
+    coll1->add_override(override1);
+
+    std::string override_tags = "";
+    auto results = coll1->search("*", {}, "",
+                                 {}, sort_fields, {2}, 10, 1, FREQUENCY,
+                                 {false}, Index::DROP_TOKENS_THRESHOLD,
+                                 spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                                 "", "", {}, 1000, true, false, true, "", false, 10000,
+                                 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                                 0, HASH, 30000, 2, "", {}, {}, "right_to_left",
+                                 true, true, false, -1, "", override_tags).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+    ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
+
+    // includes instead of filter_by
+    coll1->remove_override("ov-1");
+
+    override_t override2;
+    auto override_json2 = R"({
+        "id": "ov-2",
+        "rule": {
+            "query": "*",
+            "match": "exact"
+        },
+        "includes": [
+            {"id": "1", "position": 1}
+        ]
+    })"_json;
+
+    op = override_t::parse(override_json2, "ov-2", override2);
+    ASSERT_TRUE(op.ok());
+    coll1->add_override(override2);
+
+    results = coll1->search("*", {}, "",
+                            {}, sort_fields, {2}, 10, 1, FREQUENCY,
+                            {false}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set<std::string>(),
+                            spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                            "", "", {}, 1000, true, false, true, "", false, 10000,
+                            4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                            0, HASH, 30000, 2, "", {}, {}, "right_to_left",
+                            true, true, false, -1, "", override_tags).get();
+
+    ASSERT_EQ(3, results["hits"].size());
+    ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
+
+    collectionManager.drop_collection("coll1");
+}