diff --git a/include/index.h b/include/index.h index d3d47b9e..d62b2e2e 100644 --- a/include/index.h +++ b/include/index.h @@ -816,8 +816,10 @@ public: const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const infix_t& field_infix, const uint8_t field_id, const string& field_name, const std::vector& query_tokens, Topster* actual_topster, - size_t field_num_results, size_t& all_result_ids_len, - spp::sparse_hash_set& groups_processed, uint32_t*& all_result_ids) const; + const uint32_t *filter_ids, size_t filter_ids_length, + const std::vector& curated_ids_sorted, + size_t field_num_results, uint32_t*& all_result_ids, size_t& all_result_ids_len, + spp::sparse_hash_set& groups_processed) const; void do_synonym_search(const std::vector& filters, const std::map>& included_ids_map, diff --git a/src/index.cpp b/src/index.cpp index dffb6a37..154706e7 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2613,8 +2613,8 @@ void Index::search_fields(const std::vector& filters, do_infix_search(sort_fields_std, searched_queries, group_limit, group_by_fields, max_extra_prefix, max_extra_suffix, field_infix, field_id, field_name, query_tokens, - actual_topster, field_num_results, all_result_ids_len, groups_processed, - all_result_ids); + actual_topster, filter_ids, filter_ids_length, + curated_ids_sorted, field_num_results, all_result_ids, all_result_ids_len, groups_processed); } else if(actual_filter_ids_length != 0) { // indicates phrase match query curate_filtered_ids(filters, curated_ids, exclude_token_ids, @@ -2787,8 +2787,11 @@ void Index::do_infix_search(const std::vector& sort_fields_std, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const infix_t& field_infix, const uint8_t field_id, const string& field_name, const std::vector& query_tokens, Topster* actual_topster, - size_t field_num_results, size_t& all_result_ids_len, - spp::sparse_hash_set& groups_processed, uint32_t*& all_result_ids) const { + const uint32_t *filter_ids, size_t filter_ids_length, + const std::vector& curated_ids_sorted, + size_t field_num_results, uint32_t*& all_result_ids, size_t& all_result_ids_len, + spp::sparse_hash_set& groups_processed) const { + if(field_infix == always || (field_infix == fallback && field_num_results == 0)) { std::vector infix_ids; search_infix(query_tokens[0].value, field_name, infix_ids, max_extra_prefix, max_extra_suffix); @@ -2800,21 +2803,49 @@ void Index::do_infix_search(const std::vector& sort_fields_std, populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values); uint32_t token_bits = 255; - for(auto seq_id: infix_ids) { + std::sort(infix_ids.begin(), infix_ids.end()); + infix_ids.erase(std::unique( infix_ids.begin(), infix_ids.end() ), infix_ids.end()); + + uint32_t *raw_infix_ids = nullptr; + size_t raw_infix_ids_length = 0; + + if(curated_ids_sorted.size() != 0) { + raw_infix_ids_length = ArrayUtils::exclude_scalar(&infix_ids[0], infix_ids.size(), &curated_ids_sorted[0], + curated_ids_sorted.size(), &raw_infix_ids); + infix_ids.clear(); + } else { + raw_infix_ids = &infix_ids[0]; + raw_infix_ids_length = infix_ids.size(); + } + + if(filter_ids_length != 0) { + uint32_t *filtered_raw_infix_ids = nullptr; + raw_infix_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, raw_infix_ids, + raw_infix_ids_length, &filtered_raw_infix_ids); + if(raw_infix_ids != &infix_ids[0]) { + delete [] raw_infix_ids; + } + + raw_infix_ids = filtered_raw_infix_ids; + } + + for(size_t i = 0; i < raw_infix_ids_length; i++) { + auto seq_id = raw_infix_ids[i]; score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 2, actual_topster, {}, groups_processed, seq_id, sort_order, field_values, geopoint_indices, group_limit, group_by_fields, token_bits, false, false, {}); } - std::sort(infix_ids.begin(), infix_ids.end()); - infix_ids.erase(std::unique( infix_ids.begin(), infix_ids.end() ), infix_ids.end()); - uint32_t* new_all_result_ids = nullptr; - all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &infix_ids[0], - infix_ids.size(), &new_all_result_ids); + all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, raw_infix_ids, + raw_infix_ids_length, &new_all_result_ids); delete[] all_result_ids; all_result_ids = new_all_result_ids; + + if(raw_infix_ids != &infix_ids[0]) { + delete [] raw_infix_ids; + } } } } diff --git a/test/collection_infix_search_test.cpp b/test/collection_infix_search_test.cpp index b20f8271..528f6460 100644 --- a/test/collection_infix_search_test.cpp +++ b/test/collection_infix_search_test.cpp @@ -112,6 +112,56 @@ TEST_F(CollectionInfixSearchTest, InfixBasics) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionInfixSearchTest, InfixWithFiltering) { + std::vector fields = {field("title", field_types::STRING, false, false, true, "", -1, 1), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["title"] = "GH100037IN8900X"; + doc1["points"] = 100; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["title"] = "XH100037IN8900X"; + doc2["points"] = 200; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + + auto results = coll1->search("37IN8", + {"title"}, "points: 200", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, "", "", {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + // filtering + exclusion via curation + nlohmann::json doc3; + doc3["id"] = "2"; + doc3["title"] = "RH100037IN8900X"; + doc3["points"] = 300; + ASSERT_TRUE(coll1->add(doc3.dump()).ok()); + + results = coll1->search("37IN8", {"title"}, "points:>= 200", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, "", "2", {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionInfixSearchTest, RespectPrefixAndSuffixLimits) { std::vector fields = {field("title", field_types::STRING, false, false, true, "", -1, 1), field("points", field_types::INT32, false),};