fixing tests

This commit is contained in:
krunal 2023-10-06 14:23:01 +05:30
parent ad48c4170e
commit 073f7a1630
4 changed files with 45 additions and 9 deletions

View File

@ -544,7 +544,7 @@ private:
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size = 200);
std::vector<group_by_field_it_t> get_group_by_field_iterators(const std::vector<std::string>&) const;
std::vector<group_by_field_it_t> get_group_by_field_iterators(const std::vector<std::string>&, bool is_reverse=false) const;
public:
// for limiting number of results on multiple candidates / query rewrites

View File

@ -100,6 +100,7 @@ public:
void reset_cache();
[[nodiscard]] bool valid() const;
void next();
void previous();
void skip_to(uint32_t id);
void set_index(uint32_t index);
[[nodiscard]] uint32_t id() const;
@ -174,6 +175,8 @@ public:
iterator_t new_iterator(block_t* start_block = nullptr, block_t* end_block = nullptr, uint32_t field_id = 0);
iterator_t new_rev_iterator();
static void merge(const std::vector<posting_list_t*>& posting_lists, std::vector<uint32_t>& result_ids);
static void intersect(const std::vector<posting_list_t*>& posting_lists, std::vector<uint32_t>& result_ids);

View File

@ -1259,14 +1259,15 @@ int64_t Index::get_doc_val_from_sort_index(sort_index_iterator sort_index_it, ui
return INT64_MAX;
}
std::vector<group_by_field_it_t> Index::get_group_by_field_iterators(const std::vector<std::string>& group_by_fields) const {
std::vector<group_by_field_it_t> Index::get_group_by_field_iterators(const std::vector<std::string>& group_by_fields,
bool is_reverse) const {
std::vector<group_by_field_it_t> group_by_field_it_vec;
for (const auto &field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
auto facet_index_it = is_reverse ? facet_index->new_rev_iterator() : facet_index->new_iterator();
group_by_field_it_t group_by_field_it_struct {field_name, std::move(facet_index_it)};
group_by_field_it_vec.emplace_back(std::move(group_by_field_it_struct));
@ -2393,7 +2394,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
std::vector<group_by_field_it_t> group_by_field_it_vec;
if (group_limit != 0) {
group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
group_by_field_it_vec = get_group_by_field_iterators(group_by_fields, true);
}
while (it.valid()) {
@ -2402,8 +2403,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
if (group_limit != 0) {
distinct_id = 1;
for(auto& kv : group_by_field_it_vec) {
auto facet_index_it = kv.it.clone();
get_distinct_id(kv.field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
kv.it.previous();
}
if(excluded_group_ids.count(distinct_id) != 0) {
continue;
@ -2492,6 +2493,11 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
} else {
pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor);
}
std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
return x.second < y.second;
});
filter_result_iterator->reset();
if (filter_result_iterator->is_valid && !filter_result_iterator->reference.empty()) {
@ -2532,9 +2538,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
if (group_limit != 0) {
distinct_id = 1;
for(auto &kv : group_by_field_it_vec) {
//auto facet_index_it = kv.it.clone();
auto &facet_index_it = kv.it;
get_distinct_id(kv.field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {

View File

@ -993,6 +993,16 @@ posting_list_t::iterator_t posting_list_t::new_iterator(block_t* start_block, bl
return posting_list_t::iterator_t(&id_block_map, start_block, end_block, true, field_id);
}
posting_list_t::iterator_t posting_list_t::new_rev_iterator() {
block_t* start_block = nullptr;
if(!id_block_map.empty()) {
start_block = id_block_map.rbegin()->second;
}
auto rev_it = posting_list_t::iterator_t(&id_block_map, start_block, nullptr, true);
return rev_it;
}
void posting_list_t::advance_all(std::vector<posting_list_t::iterator_t>& its) {
for(auto& it: its) {
it.next();
@ -1677,6 +1687,25 @@ void posting_list_t::iterator_t::next() {
}
}
void posting_list_t::iterator_t::previous() {
curr_index--;
if(curr_index < 0) {
// since block stores only the next pointer, we have to use `id_block_map` for reverse iteration
auto last_ele = ids[curr_block->size()-1];
auto it = id_block_map->find(last_ele);
if(it != id_block_map->end() && it != id_block_map->begin()) {
it--;
curr_block = it->second;
curr_index = curr_block->size()-1;
delete [] ids;
ids = curr_block->ids.uncompress();
} else {
curr_block = end_block;
}
}
}
uint32_t posting_list_t::iterator_t::last_block_id() const {
auto size = curr_block->size();
if(size == 0) {