diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 649da61a..4c7cf196 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -38,15 +38,18 @@ jobs:
         uses: bazelbuild/setup-bazelisk@v2
 
       - name: Download bazel cache
-        uses: dawidd6/action-download-artifact@v2
+        uses: dawidd6/action-download-artifact@v2.28.0
         with:
           name: bazel-cache
           search_artifacts: true
           workflow_conclusion: ""
           if_no_artifact_found: warn
+          skip_unpack: true
 
       - name: Uncompress bazel cache
         run: |
+          # the artifact may be absent (if_no_artifact_found: warn), so only unzip when it was downloaded
+          if [ -f bazel-cache.zip ]; then unzip bazel-cache.zip; fi
           mkdir -p ~/.cache/bazel
           tar_file="bazel-cache.tar.gz" && \
           [ -f "$tar_file" ] && \
diff --git a/include/index.h b/include/index.h
index 1d895c02..8020c23f 100644
--- a/include/index.h
+++ b/include/index.h
@@ -269,11 +269,23 @@ class VectorFilterFunctor: public hnswlib::BaseFilterFunctor {
     const uint32_t* filter_ids = nullptr;
     const uint32_t filter_ids_length = 0;
 
+    const uint32_t* excluded_ids = nullptr;
+    const uint32_t excluded_ids_length = 0;
+
 public:
-    explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
-        filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
+    explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length, const uint32_t* excluded_ids = nullptr, const uint32_t excluded_ids_length = 0) :
+        filter_ids(filter_ids), filter_ids_length(filter_ids_length), excluded_ids(excluded_ids), excluded_ids_length(excluded_ids_length) {}
 
     bool operator()(hnswlib::labeltype id) override {
+        if(filter_ids_length == 0 && excluded_ids_length == 0) {
+            return true;
+        }
+
+        // std::binary_search requires excluded_ids to be sorted in ascending order
+        if(excluded_ids_length > 0 && excluded_ids && std::binary_search(excluded_ids, excluded_ids + excluded_ids_length, id)) {
+            return false;
+        }
+
         if(filter_ids_length == 0) {
             return true;
         }
diff --git a/src/index.cpp b/src/index.cpp
index 2e828aab..d39e3fa7 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -2901,7 +2901,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                         k++;
                     }
 
-                    VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
+                    VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
                     auto& field_vector_index = vector_index.at(vector_query.field_name);
 
                     std::vector<std::pair<float, size_t>> dist_labels;
@@ -3207,7 +3207,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
                 const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
 
-                VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
+                VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
                 auto& field_vector_index = vector_index.at(vector_query.field_name);
                 std::vector<std::pair<float, size_t>> dist_labels;
                 // use k as 100 by default for ensuring results stability in pagination
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index 72244efe..299b001c 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -2822,4 +2822,85 @@ TEST_F(CollectionVectorTest, TestSemanticSearchAfterUpdate) {
     ASSERT_TRUE(result.ok());
     ASSERT_EQ(1, result.get()["hits"].size());
     ASSERT_EQ("potato", result.get()["hits"][0]["document"]["name"]);
+}
+
+
+TEST_F(CollectionVectorTest, TestHybridSearchHiddenHits) {
+    nlohmann::json schema = R"({
+                "name": "test",
+                "fields": [
+                    {
+                        "name": "name",
+                        "type": "string"
+                    },
+                    {
+                        "name": "embedding",
+                        "type": "float[]",
+                        "embed": {
+                            "from": [
+                                "name"
+                            ],
+                            "model_config": {
+                                "model_name": "ts/e5-small"
+                            }
+                        }
+                    }
+                ]
+            })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(collection_create_op.ok());
+
+    auto coll = collection_create_op.get();
+
+    auto add_op = coll->add(R"({
+        "name": "soccer",
+        "id": "0"
+    })"_json.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll->add(R"({
+        "name": "guitar",
+        "id": "1"
+    })"_json.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll->add(R"({
+        "name": "typesense",
+        "id": "2"
+    })"_json.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll->add(R"({
+        "name": "potato",
+        "id": "3"
+    })"_json.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    auto results = coll->search("sports", {"name", "embedding"},
+                                "", {}, {}, {2}, 10,
+                                1, FREQUENCY, {true},
+                                0, spp::sparse_hash_set<std::string>()).get();
+
+    ASSERT_EQ(4, results["hits"].size());
+    ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
+
+
+    // do hybrid search with hidden_hits
+    auto hybrid_results = coll->search("sports", {"name", "embedding"},
+                                "", {}, {}, {2}, 10,
+                                1, FREQUENCY, {true},
+                                0, spp::sparse_hash_set<std::string>(), spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 1, "", "0").get();
+
+    ASSERT_EQ(3, hybrid_results["hits"].size());
+    // document ids are JSON strings: comparing against the integer 0 can never match, so compare against "0"
+    for(const auto& hit : hybrid_results["hits"]) {
+        ASSERT_NE("0", hit["document"]["id"]);
+    }
 }
\ No newline at end of file