diff --git a/include/num_tree.h b/include/num_tree.h index 444f6266..38c57c8a 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -65,4 +65,9 @@ public: uint32_t* const& context_ids, size_t& result_ids_len, uint32_t*& result_ids) const; + + void merge_id_list_iterators(std::vector& id_list_iterators, + const NUM_COMPARATOR &comparator, + uint32_t*& result_ids, + uint32_t& result_ids_len) const; }; \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 17d2c987..488a9b8c 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -651,22 +651,40 @@ void filter_result_iterator_t::init() { for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); + std::vector id_list_iterators; + std::vector expanded_id_lists; - size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); + num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + value, id_list_iterators, expanded_id_lists); } - filter_result.count = result_size; + uint32_t* filter_match_ids = nullptr; + uint32_t filter_ids_length; + num_tree->merge_id_list_iterators(id_list_iterators, a_filter.comparators[fi], + filter_match_ids, filter_ids_length); + + if (a_filter.comparators[fi] == NOT_EQUALS) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_match_ids, filter_ids_length); + } + + uint32_t *out = nullptr; + filter_result.count = ArrayUtils::or_scalar(filter_match_ids, filter_ids_length, + filter_result.docs, filter_result.count, &out); + + delete [] filter_match_ids; + delete [] filter_result.docs; + filter_result.docs = out; + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } } if (a_filter.apply_not_equals) { diff --git a/src/num_tree.cpp b/src/num_tree.cpp index 89c5e3a0..d254ca43 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -429,3 +429,40 @@ num_tree_t::~num_tree_t() { ids_t::destroy_list(kv.second); } } + +void num_tree_t::merge_id_list_iterators(std::vector& id_list_iterators, + const NUM_COMPARATOR &comparator, + uint32_t*& result_ids, + uint32_t& result_ids_len) const { + struct comp { + bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { + return lhs->id() > rhs->id(); + } + }; + + std::priority_queue, comp> iter_queue; + for (auto& id_list_iterator: id_list_iterators) { + if (id_list_iterator.valid()) { + iter_queue.push(&id_list_iterator); + } + } + + std::vector consolidated_ids; + while (!iter_queue.empty()) { + id_list_t::iterator_t* iter = iter_queue.top(); + iter_queue.pop(); + + consolidated_ids.push_back(iter->id()); + iter->next(); + + if (iter->valid()) { + iter_queue.push(iter); + } + } + + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + result_ids_len = consolidated_ids.size(); + result_ids = new uint32_t[consolidated_ids.size()]; + std::copy(consolidated_ids.begin(), consolidated_ids.end(), result_ids); +}