Refactor `VectorFilterFunctor` to include `excluded_ids`
This commit is contained in:
ozanarmagan 2023-11-23 22:23:15 +03:00
parent 2b154226ca
commit a2c5d24802
2 changed files with 15 additions and 16 deletions

View File

@ -269,11 +269,22 @@ class VectorFilterFunctor: public hnswlib::BaseFilterFunctor {
const uint32_t* filter_ids = nullptr;
const uint32_t filter_ids_length = 0;
const uint32_t* excluded_ids = nullptr;
const uint32_t excluded_ids_length = 0;
public:
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
// Constructs the functor from a filter id array and an optional excluded id array.
// Only the pointers and lengths are stored in the members; the arrays are not copied here.
// excluded_ids / excluded_ids_length default to nullptr / 0, so the last two arguments may be omitted.
// NOTE(review): operator() looks ids up in excluded_ids with std::binary_search, so the
// caller presumably has to pass that array sorted in ascending order — confirm at call sites.
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length, const uint32_t* excluded_ids = nullptr, const uint32_t excluded_ids_length = 0) :
filter_ids(filter_ids), filter_ids_length(filter_ids_length), excluded_ids(excluded_ids), excluded_ids_length(excluded_ids_length) {}
bool operator()(hnswlib::labeltype id) override {
if(filter_ids_length == 0 && excluded_ids_length == 0) {
return true;
}
if(excluded_ids_length > 0 && excluded_ids && std::binary_search(excluded_ids, excluded_ids + excluded_ids_length, id)) {
return false;
}
if(filter_ids_length == 0) {
return true;
}

View File

@ -2901,7 +2901,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
k++;
}
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
@ -3206,20 +3206,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
// For hybrid search, we need to give weight to text match and vector search
const float VECTOR_SEARCH_WEIGHT = vector_query.alpha;
const float TEXT_MATCH_WEIGHT = 1.0 - VECTOR_SEARCH_WEIGHT;
bool no_filters_provided = (filter_tree_root == nullptr && filter_result.count == 0);
// list of all document ids
if (no_filters_provided) {
filter_result.count = seq_ids->num_ids();
filter_result.docs = seq_ids->uncompress();
}
curate_filtered_ids(curated_ids, excluded_result_ids,
excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted);
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count, excluded_result_ids, excluded_result_ids_size);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
// use k as 100 by default for ensuring results stability in pagination