diff --git a/src/collection.cpp b/src/collection.cpp index 2d7c9873..f1f79e97 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1152,6 +1152,11 @@ void Collection::parse_search_query(const std::string &query, std::vector leaf_to_indices; std::vector query_suggestion; + if(searched_queries.size() <= field_order_kv->query_index) { + return ; + } + for (const art_leaf *token_leaf : searched_queries[field_order_kv->query_index]) { // Must search for the token string fresh on that field for the given document since `token_leaf` // is from the best matched field and need not be present in other fields of a document. diff --git a/src/index.cpp b/src/index.cpp index 7db552ad..ed18b108 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1325,6 +1325,31 @@ void Index::search(Option & outcome, //auto begin = std::chrono::high_resolution_clock::now(); uint32_t* all_result_ids = nullptr; + const size_t num_search_fields = std::min(search_fields.size(), (size_t) FIELD_LIMIT_NUM); + uint32_t *exclude_token_ids = nullptr; + size_t exclude_token_ids_size = 0; + + // find documents that contain the excluded tokens to exclude them from results later + for(size_t i = 0; i < num_search_fields; i++) { + const std::string & field_name = search_fields[i]; + for(const std::string& exclude_token: q_exclude_tokens) { + art_leaf* leaf = (art_leaf *) art_search(search_index.at(field_name), + (const unsigned char *) exclude_token.c_str(), + exclude_token.size() + 1); + + if(leaf) { + uint32_t *ids = leaf->values->ids.uncompress(); + uint32_t *exclude_token_ids_merged = nullptr; + exclude_token_ids_size = ArrayUtils::or_scalar(exclude_token_ids, exclude_token_ids_size, ids, + leaf->values->ids.getLength(), + &exclude_token_ids_merged); + delete[] ids; + delete[] exclude_token_ids; + exclude_token_ids = exclude_token_ids_merged; + } + } + } + if(!q_include_tokens.empty() && q_include_tokens[0] == "*") { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); const std::string & field = search_fields[0]; @@ -1364,6 +1389,20 @@ void Index::search(Option & outcome, filter_ids = excluded_result_ids; } + // Exclude document IDs associated with excluded tokens from the result set + if(exclude_token_ids_size != 0) { + if(filters.empty()) { + // filter ids populated from hash map will not be sorted, but sorting is required for intersection + std::sort(filter_ids, filter_ids+filter_ids_length); + } + + uint32_t *excluded_result_ids = nullptr; + filter_ids_length = ArrayUtils::exclude_scalar(filter_ids, filter_ids_length, exclude_token_ids, + exclude_token_ids_size, &excluded_result_ids); + delete[] filter_ids; + filter_ids = excluded_result_ids; + } + score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {}, groups_processed, filter_ids, filter_ids_length); collate_included_ids(q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries); @@ -1372,31 +1411,6 @@ void Index::search(Option & outcome, all_result_ids = filter_ids; filter_ids = nullptr; } else { - const size_t num_search_fields = std::min(search_fields.size(), (size_t) FIELD_LIMIT_NUM); - uint32_t *exclude_token_ids = nullptr; - size_t exclude_token_ids_size = 0; - - // find documents that contain the excluded tokens to exclude them from results later - for(size_t i = 0; i < num_search_fields; i++) { - const std::string & field_name = search_fields[i]; - for(const std::string& exclude_token: q_exclude_tokens) { - art_leaf* leaf = (art_leaf *) art_search(search_index.at(field_name), - (const unsigned char *) exclude_token.c_str(), - exclude_token.size() + 1); - - if(leaf) { - uint32_t *ids = leaf->values->ids.uncompress(); - uint32_t *exclude_token_ids_merged = nullptr; - exclude_token_ids_size = ArrayUtils::or_scalar(exclude_token_ids, exclude_token_ids_size, ids, - leaf->values->ids.getLength(), - &exclude_token_ids_merged); - delete[] ids; - delete[] exclude_token_ids; - exclude_token_ids = exclude_token_ids_merged; - } - } - } - for(size_t i = 0; i < num_search_fields; i++) { // proceed to query search only when no filters are provided or when filtering produces results if(filters.empty() || filter_ids_length > 0) { @@ -1416,10 +1430,10 @@ void Index::search(Option & outcome, collate_included_ids(q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries); } } - - delete [] exclude_token_ids; } + delete [] exclude_token_ids; + do_facets(facets, facet_query, all_result_ids, all_result_ids_len); do_facets(facets, facet_query, &included_ids[0], included_ids.size()); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 63b2701b..7ecdcbf5 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -220,6 +220,15 @@ TEST_F(CollectionTest, SearchWithExcludedTokens) { std::string result_id = result["document"]["id"]; ASSERT_STREQ(id.c_str(), result_id.c_str()); } + + results = collection->search("-rocket", query_fields, "", facets, sort_fields, 0, 50).get(); + + ASSERT_EQ(21, results["found"].get()); + ASSERT_EQ(21, results["hits"].size()); + + results = collection->search("-rocket -cryovolcanism", query_fields, "", facets, sort_fields, 0, 50).get(); + + ASSERT_EQ(20, results["found"].get()); } TEST_F(CollectionTest, SkipUnindexedTokensDuringPhraseSearch) {