diff --git a/include/collection.h b/include/collection.h index 529af651..57107cd2 100644 --- a/include/collection.h +++ b/include/collection.h @@ -497,7 +497,8 @@ public: const size_t remote_embedding_num_tries = 2, const std::string& stopwords_set="", const std::vector& facet_return_parent = {}, - const std::vector& ref_include_fields_vec = {}) const; + const std::vector& ref_include_fields_vec = {}, + const drop_tokens_mode_t drop_tokens_mode = right_to_left) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/field.h b/include/field.h index 9798e814..ccdb8784 100644 --- a/include/field.h +++ b/include/field.h @@ -367,10 +367,6 @@ struct field { return Option(400, "Field `" + field.name + "` must be an optional field."); } - if(!field.index && !field.optional) { - return Option(400, "Field `" + field.name + "` must be optional since it is marked as non-indexable."); - } - if(field.name == ".*" && !field.index) { return Option(400, "Field `" + field.name + "` cannot be marked as non-indexable."); } diff --git a/include/index.h b/include/index.h index fa2b5544..c4e65e16 100644 --- a/include/index.h +++ b/include/index.h @@ -98,6 +98,11 @@ enum text_match_type_t { max_weight }; +enum drop_tokens_mode_t { + left_to_right, + right_to_left, +}; + struct search_args { std::vector field_query_tokens; std::vector search_fields; @@ -146,6 +151,7 @@ struct search_args { vector_query_t& vector_query; size_t facet_sample_percent; size_t facet_sample_threshold; + drop_tokens_mode_t drop_tokens_mode; search_args(std::vector field_query_tokens, std::vector search_fields, const text_match_type_t match_type, @@ -161,7 +167,7 @@ struct search_args { size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, const enable_t split_join_tokens, vector_query_t& vector_query, - size_t facet_sample_percent, size_t facet_sample_threshold) : + size_t facet_sample_percent, size_t facet_sample_threshold, drop_tokens_mode_t drop_tokens_mode) : field_query_tokens(field_query_tokens), search_fields(search_fields), match_type(match_type), filter_tree_root(filter_tree_root), facets(facets), included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std), @@ -176,7 +182,8 @@ struct search_args { infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix), facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits), split_join_tokens(split_join_tokens), vector_query(vector_query), - facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold) { + facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold), + drop_tokens_mode(drop_tokens_mode) { const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory topster = new Topster(topster_size, group_limit); @@ -641,7 +648,8 @@ public: const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, enable_t split_join_tokens, const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold, - const std::string& collection_name, facet_index_type_t facet_index_type = DETECT) const; + const std::string& collection_name, facet_index_type_t facet_index_type = DETECT, + const drop_tokens_mode_t drop_tokens_mode = right_to_left) const; void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name, const bool is_update); diff --git a/src/collection.cpp b/src/collection.cpp index 2742047a..12f12e90 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1389,7 +1389,8 @@ Option Collection::search(std::string raw_query, const size_t remote_embedding_num_tries, const std::string& stopwords_set, const std::vector& facet_return_parent, - const std::vector& ref_include_fields_vec) const { + const std::vector& ref_include_fields_vec, + const drop_tokens_mode_t drop_tokens_mode) const { std::shared_lock lock(mutex); @@ -1888,7 +1889,7 @@ Option Collection::search(std::string raw_query, min_len_1typo, min_len_2typo, max_candidates, infixes, max_extra_prefix, max_extra_suffix, facet_query_num_typos, filter_curated_hits, split_join_tokens, vector_query, - facet_sample_percent, facet_sample_threshold); + facet_sample_percent, facet_sample_threshold, drop_tokens_mode); std::unique_ptr search_params_guard(search_params); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f7727688..78c99a2d 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -976,6 +976,8 @@ Option CollectionManager::do_search(std::map& re const char *FACET_SAMPLE_PERCENT = "facet_sample_percent"; const char *FACET_SAMPLE_THRESHOLD = "facet_sample_threshold"; + const char *DROP_TOKENS_MODE = "drop_tokens_mode"; + // enrich params with values from embedded params for(auto& item: embedded_params.items()) { if(item.key() == "expires_at") { @@ -1096,6 +1098,8 @@ Option CollectionManager::do_search(std::map& re size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; + std::string drop_tokens_mode_str = "right_to_left"; + std::unordered_map unsigned_int_values = { {MIN_LEN_1TYPO, &min_len_1typo}, {MIN_LEN_2TYPO, &min_len_2typo}, @@ -1132,6 +1136,7 @@ Option CollectionManager::do_search(std::map& re {HIGHLIGHT_END_TAG, &highlight_end_tag}, {PINNED_HITS, &pinned_hits_str}, {HIDDEN_HITS, &hidden_hits_str}, + {DROP_TOKENS_MODE, &drop_tokens_mode_str}, }; std::unordered_map bool_values = { @@ -1293,6 +1298,12 @@ Option CollectionManager::do_search(std::map& re Index::NUM_CANDIDATES_DEFAULT_MIN); } + auto drop_tokens_mode_op = magic_enum::enum_cast(drop_tokens_mode_str); + drop_tokens_mode_t drop_tokens_mode; + if(drop_tokens_mode_op.has_value()) { + drop_tokens_mode = drop_tokens_mode_op.value(); + } + Option result_op = collection->search(raw_query, search_fields, filter_query, facet_fields, sort_fields, num_typos, per_page, @@ -1341,7 +1352,8 @@ Option CollectionManager::do_search(std::map& re remote_embedding_num_tries, stopwords_set, facet_return_parent, - ref_include_fields_vec); + ref_include_fields_vec, + drop_tokens_mode); uint64_t timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); diff --git a/src/field.cpp b/src/field.cpp index 2c5199ad..7beef41c 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -549,6 +549,10 @@ Option field::flatten_doc(nlohmann::json& document, std::unordered_map flattened_fields_map; for(auto& nested_field: nested_fields) { + if(!nested_field.index) { + continue; + } + std::vector field_parts; StringUtils::split(nested_field.name, field_parts, "."); diff --git a/src/index.cpp b/src/index.cpp index fc9c57ed..490b044f 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1795,7 +1795,8 @@ Option Index::run_search(search_args* search_params, const std::string& co search_params->facet_sample_percent, search_params->facet_sample_threshold, collection_name, - facet_index_type); + facet_index_type, + search_params->drop_tokens_mode); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -2244,7 +2245,9 @@ Option Index::search(std::vector& field_query_tokens, cons const bool filter_curated_hits, const enable_t split_join_tokens, const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold, - const std::string& collection_name, facet_index_type_t facet_index_type) const { + const std::string& collection_name, + facet_index_type_t facet_index_type, + const drop_tokens_mode_t drop_tokens_mode) const { std::shared_lock lock(mutex); auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root); @@ -2644,16 +2647,24 @@ Option Index::search(std::vector& field_query_tokens, cons for (size_t qi = 0; qi < all_queries.size(); qi++) { auto& orig_tokens = all_queries[qi]; size_t num_tokens_dropped = 0; + auto curr_direction = drop_tokens_mode; + size_t total_dirs_done = 0; while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) { // When atleast two tokens from the query are available we can drop one std::vector truncated_tokens; std::vector dropped_tokens; - if(orig_tokens.size() > 1 && num_tokens_dropped < 2*(orig_tokens.size()-1)) { - bool prefix_search = false; + if(num_tokens_dropped >= orig_tokens.size() - 1) { + // swap direction and reset counter + curr_direction = (curr_direction == right_to_left) ? left_to_right : right_to_left; + num_tokens_dropped = 0; + total_dirs_done++; + } - if (num_tokens_dropped < orig_tokens.size() - 1) { + if(orig_tokens.size() > 1 && total_dirs_done < 2) { + bool prefix_search = false; + if (curr_direction == right_to_left) { // drop from right size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1; for (size_t i = 0; i < orig_tokens.size(); i++) { @@ -2666,7 +2677,7 @@ Option Index::search(std::vector& field_query_tokens, cons } else { // drop from left prefix_search = true; - size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1; + size_t start_index = (num_tokens_dropped + 1); for(size_t i = 0; i < orig_tokens.size(); i++) { if(i >= start_index) { truncated_tokens.emplace_back(orig_tokens[i]); diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 2dc8dd04..08d02490 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -1279,7 +1279,7 @@ TEST_F(CollectionAllFieldsTest, DoNotIndexFieldMarkedAsNonIndex) { auto op = collectionManager.create_collection("coll2", 1, fields, "", 0, field_types::AUTO); ASSERT_FALSE(op.ok()); - ASSERT_EQ("Field `post` must be optional since it is marked as non-indexable.", op.error()); + ASSERT_EQ("Field `.*_txt` cannot be a facet since it's marked as non-indexable.", op.error()); fields = {field("company_name", field_types::STRING, false), field("num_employees", field_types::INT32, false), diff --git a/test/collection_nested_fields_test.cpp b/test/collection_nested_fields_test.cpp index 98a94f37..f92f7e0b 100644 --- a/test/collection_nested_fields_test.cpp +++ b/test/collection_nested_fields_test.cpp @@ -1487,6 +1487,54 @@ TEST_F(CollectionNestedFieldsTest, ExplicitSchemaForNestedArrayTypeValidation) { "Hint: field inside an array of objects must be an array type as well.", add_op.error()); } +TEST_F(CollectionNestedFieldsTest, UnindexedNestedFieldShouldNotClutterSchema) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name": "block", "type": "object", "optional": true, "index": false} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll1 = op.get(); + + auto doc1 = R"({ + "block": {"text": "Hello world."} + })"_json; + + auto add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // child fields should not become part of schema + ASSERT_EQ(1, coll1->get_fields().size()); +} + +TEST_F(CollectionNestedFieldsTest, UnindexedNonOptionalFieldShouldBeAllowed) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name": "block", "type": "object", "index": false} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll1 = op.get(); + + auto doc1 = R"({ + "block": {"text": "Hello world."} + })"_json; + + auto add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // child fields should not become part of schema + ASSERT_EQ(1, coll1->get_fields().size()); +} + TEST_F(CollectionNestedFieldsTest, SortByNestedField) { nlohmann::json schema = R"({ "name": "coll1", diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 1c77decb..d6538c86 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -2332,6 +2332,47 @@ TEST_F(CollectionSpecificMoreTest, ExhaustiveSearchWithoutExplicitDropTokens) { ASSERT_EQ(2, res["hits"].size()); } +TEST_F(CollectionSpecificMoreTest, DropTokensLeftToRightFirst) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string"} + ] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + doc["title"] = "alpha beta"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["title"] = "beta gamma"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + bool exhaustive_search = false; + size_t drop_tokens_threshold = 1; + + auto res = coll1->search("alpha beta gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 10000, + 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, + 0, HASH, 30000, 2, "", {}, {}, left_to_right).get(); + + ASSERT_EQ(1, res["hits"].size()); + ASSERT_EQ("1", res["hits"][0]["document"]["id"].get()); + + res = coll1->search("alpha beta gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 10000, + 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, + 0, HASH, 30000, 2, "", {}, {}, right_to_left).get(); + + ASSERT_EQ(1, res["hits"].size()); + ASSERT_EQ("0", res["hits"][0]["document"]["id"].get()); +} + TEST_F(CollectionSpecificMoreTest, DoNotHighlightFieldsForSpecialCharacterQuery) { nlohmann::json schema = R"({ "name": "coll1",