Merge pull request #1393 from ozanarmagan/v0.25-join

Fix hybrid search with filters
This commit is contained in:
Kishore Nallan 2023-11-24 11:06:17 +05:30 committed by GitHub
commit c9ee8c9128
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 5 deletions

View File

@ -38,15 +38,17 @@ jobs:
uses: bazelbuild/setup-bazelisk@v2
- name: Download bazel cache
uses: dawidd6/action-download-artifact@v2
uses: dawidd6/action-download-artifact@v2.28.0
with:
name: bazel-cache
search_artifacts: true
workflow_conclusion: ""
if_no_artifact_found: warn
skip_unpack: true
- name: Uncompress bazel cache
run: |
unzip bazel-cache.zip
mkdir -p ~/.cache/bazel
tar_file="bazel-cache.tar.gz" && \
[ -f "$tar_file" ] && \

View File

@ -269,11 +269,22 @@ class VectorFilterFunctor: public hnswlib::BaseFilterFunctor {
const uint32_t* filter_ids = nullptr;
const uint32_t filter_ids_length = 0;
const uint32_t* excluded_ids = nullptr;
const uint32_t excluded_ids_length = 0;
public:
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length, const uint32_t* excluded_ids = nullptr, const uint32_t excluded_ids_length = 0) :
filter_ids(filter_ids), filter_ids_length(filter_ids_length), excluded_ids(excluded_ids), excluded_ids_length(excluded_ids_length) {}
bool operator()(hnswlib::labeltype id) override {
if(filter_ids_length == 0 && excluded_ids_length == 0) {
return true;
}
if(excluded_ids_length > 0 && excluded_ids && std::binary_search(excluded_ids, excluded_ids + excluded_ids_length, id)) {
return false;
}
if(filter_ids_length == 0) {
return true;
}

View File

@ -2901,7 +2901,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
k++;
}
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
@ -3207,7 +3207,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
// use k as 100 by default for ensuring results stability in pagination

View File

@ -2822,4 +2822,82 @@ TEST_F(CollectionVectorTest, TestSemanticSearchAfterUpdate) {
ASSERT_TRUE(result.ok());
ASSERT_EQ(1, result.get()["hits"].size());
ASSERT_EQ("potato", result.get()["hits"][0]["document"]["name"]);
}
TEST_F(CollectionVectorTest, TestHybridSearchHiddenHits) {
nlohmann::json schema = R"({
"name": "test",
"fields": [
{
"name": "name",
"type": "string"
},
{
"name": "embedding",
"type": "float[]",
"embed": {
"from": [
"name"
],
"model_config": {
"model_name": "ts/e5-small"
}
}
}
]
})"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
auto collection_create_op = collectionManager.create_collection(schema);
ASSERT_TRUE(collection_create_op.ok());
auto coll = collection_create_op.get();
auto add_op = coll->add(R"({
"name": "soccer",
"id": "0"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
add_op = coll->add(R"({
"name": "guitar",
"id": "1"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
add_op = coll->add(R"({
"name": "typesense",
"id": "2"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
add_op = coll->add(R"({
"name": "potato",
"id": "3"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
auto results = coll->search("sports", {"name", "embedding"},
"", {}, {}, {2}, 10,
1, FREQUENCY, {true},
0, spp::sparse_hash_set<std::string>()).get();
ASSERT_EQ(4, results["hits"].size());
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
// do hybrid search with hidden_hits
auto hybrid_results = coll->search("sports", {"name", "embedding"},
"", {}, {}, {2}, 10,
1, FREQUENCY, {true},
0, spp::sparse_hash_set<std::string>(), spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 1, "", "0").get();
ASSERT_EQ(3, hybrid_results["hits"].size());
ASSERT_FALSE(hybrid_results["hits"][0]["document"]["id"] == 0);
}