mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 23:06:30 +08:00
Merge pull request #1393 from ozanarmagan/v0.25-join
Fix hybrid search with filters
This commit is contained in:
commit
c9ee8c9128
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@ -38,15 +38,17 @@ jobs:
|
||||
uses: bazelbuild/setup-bazelisk@v2
|
||||
|
||||
- name: Download bazel cache
|
||||
uses: dawidd6/action-download-artifact@v2
|
||||
uses: dawidd6/action-download-artifact@v2.28.0
|
||||
with:
|
||||
name: bazel-cache
|
||||
search_artifacts: true
|
||||
workflow_conclusion: ""
|
||||
if_no_artifact_found: warn
|
||||
skip_unpack: true
|
||||
|
||||
- name: Uncompress bazel cache
|
||||
run: |
|
||||
unzip bazel-cache.zip
|
||||
mkdir -p ~/.cache/bazel
|
||||
tar_file="bazel-cache.tar.gz" && \
|
||||
[ -f "$tar_file" ] && \
|
||||
|
@ -269,11 +269,22 @@ class VectorFilterFunctor: public hnswlib::BaseFilterFunctor {
|
||||
const uint32_t* filter_ids = nullptr;
|
||||
const uint32_t filter_ids_length = 0;
|
||||
|
||||
const uint32_t* excluded_ids = nullptr;
|
||||
const uint32_t excluded_ids_length = 0;
|
||||
|
||||
public:
|
||||
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
|
||||
filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
|
||||
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length, const uint32_t* excluded_ids = nullptr, const uint32_t excluded_ids_length = 0) :
|
||||
filter_ids(filter_ids), filter_ids_length(filter_ids_length), excluded_ids(excluded_ids), excluded_ids_length(excluded_ids_length) {}
|
||||
|
||||
bool operator()(hnswlib::labeltype id) override {
|
||||
if(filter_ids_length == 0 && excluded_ids_length == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if(excluded_ids_length > 0 && excluded_ids && std::binary_search(excluded_ids, excluded_ids + excluded_ids_length, id)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if(filter_ids_length == 0) {
|
||||
return true;
|
||||
}
|
||||
|
@ -2901,7 +2901,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
k++;
|
||||
}
|
||||
|
||||
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
|
||||
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
|
||||
std::vector<std::pair<float, size_t>> dist_labels;
|
||||
@ -3207,7 +3207,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
|
||||
const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
|
||||
|
||||
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
|
||||
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
std::vector<std::pair<float, size_t>> dist_labels;
|
||||
// use k as 100 by default for ensuring results stability in pagination
|
||||
|
@ -2822,4 +2822,82 @@ TEST_F(CollectionVectorTest, TestSemanticSearchAfterUpdate) {
|
||||
ASSERT_TRUE(result.ok());
|
||||
ASSERT_EQ(1, result.get()["hits"].size());
|
||||
ASSERT_EQ("potato", result.get()["hits"][0]["document"]["name"]);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionVectorTest, TestHybridSearchHiddenHits) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "test",
|
||||
"fields": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "embedding",
|
||||
"type": "float[]",
|
||||
"embed": {
|
||||
"from": [
|
||||
"name"
|
||||
],
|
||||
"model_config": {
|
||||
"model_name": "ts/e5-small"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto collection_create_op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(collection_create_op.ok());
|
||||
|
||||
auto coll = collection_create_op.get();
|
||||
|
||||
auto add_op = coll->add(R"({
|
||||
"name": "soccer",
|
||||
"id": "0"
|
||||
})"_json.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
add_op = coll->add(R"({
|
||||
"name": "guitar",
|
||||
"id": "1"
|
||||
})"_json.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
add_op = coll->add(R"({
|
||||
"name": "typesense",
|
||||
"id": "2"
|
||||
})"_json.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
add_op = coll->add(R"({
|
||||
"name": "potato",
|
||||
"id": "3"
|
||||
})"_json.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
auto results = coll->search("sports", {"name", "embedding"},
|
||||
"", {}, {}, {2}, 10,
|
||||
1, FREQUENCY, {true},
|
||||
0, spp::sparse_hash_set<std::string>()).get();
|
||||
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
|
||||
// do hybrid search with hidden_hits
|
||||
auto hybrid_results = coll->search("sports", {"name", "embedding"},
|
||||
"", {}, {}, {2}, 10,
|
||||
1, FREQUENCY, {true},
|
||||
0, spp::sparse_hash_set<std::string>(), spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 1, "", "0").get();
|
||||
|
||||
ASSERT_EQ(3, hybrid_results["hits"].size());
|
||||
ASSERT_FALSE(hybrid_results["hits"][0]["document"]["id"] == 0);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user