Add filter_result_t struct.

Add `reference_filter_result_t` struct.
Add support for lazy filtering.
Update `rearrange_filter_tree` to return approximate count of filter matches.
This commit is contained in:
Harpreet Sangar 2023-03-03 10:37:33 +05:30
parent 2d39461eca
commit e78d209911
13 changed files with 541 additions and 131 deletions

View File

@ -5,5 +5,3 @@ build --cxxopt="-std=c++17"
test --jobs=6
build --enable_platform_specific_config
build:linux --action_env=BAZEL_LINKLIBS="-l%:libstdc++.a -l%:libgcc.a"

View File

@ -268,6 +268,8 @@ private:
Option<std::string> get_reference_field(const std::string & collection_name) const;
public:
enum {MAX_ARRAY_MATCHES = 5};
@ -455,16 +457,12 @@ public:
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;
Option<std::string> get_reference_field(const std::string & collection_name) const;
Option<bool> get_reference_filter_ids(const std::string & filter_query,
filter_result_t& filter_result,
const std::string & collection_name) const;
Option<bool> validate_reference_filter(const std::string& filter_query) const;
Option<bool> validate_reference_filter(const std::string& filter_query) const;
Option<nlohmann::json> get(const std::string & id) const;
Option<std::string> remove(const std::string & id, bool remove_from_store = true);

View File

@ -641,11 +641,18 @@ struct reference_filter_result_t {
struct filter_result_t {
uint32_t count = 0;
uint32_t* docs = nullptr;
reference_filter_result_t* reference_filter_result = nullptr;
// Collection name -> Reference filter result
std::map<std::string, reference_filter_result_t*> reference_filter_results;
filter_result_t() {}
filter_result_t(uint32_t count, uint32_t* docs) : count(count), docs(docs) {}
~filter_result_t() {
delete[] docs;
delete[] reference_filter_result;
for (const auto &item: reference_filter_results) {
delete[] item.second;
}
}
};

View File

@ -467,16 +467,28 @@ private:
void numeric_not_equals_filter(num_tree_t* const num_tree,
const int64_t value,
uint32_t*& ids,
size_t& ids_len) const;
const uint32_t& context_ids_length,
const uint32_t* context_ids,
size_t& ids_len,
uint32_t*& ids) const;
bool field_is_indexed(const std::string& field_name) const;
Option<bool> do_filtering(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name = "") const;
const std::string& collection_name = "",
const uint32_t& context_ids_length = 0,
const uint32_t* context_ids = nullptr) const;
Option<bool> rearranging_recursive_filter (filter_node_t* const filter_tree_root,
filter_result_t& result,
const std::string& collection_name = "") const;
void aproximate_numerical_match(num_tree_t* const num_tree,
const NUM_COMPARATOR& comparator,
const int64_t& value,
const int64_t& range_end_value,
uint32_t& filter_ids_length) const;
Option<bool> rearranging_recursive_filter(filter_node_t* const filter_tree_root,
filter_result_t& result,
const std::string& collection_name = "") const;
Option<bool> recursive_filter(filter_node_t* const root,
filter_result_t& result,
@ -687,7 +699,8 @@ public:
Option<bool> do_reference_filtering_with_lock(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string & reference_helper_field_name) const;
const std::string& collection_name,
const std::string& reference_helper_field_name) const;
void refresh_schemas(const std::vector<field>& new_fields, const std::vector<field>& del_fields);

View File

@ -11,6 +11,17 @@ class num_tree_t {
private:
std::map<int64_t, void*> int64map;
[[nodiscard]] bool range_inclusive_contains(const int64_t& start, const int64_t& end, const uint32_t& id) const;
[[nodiscard]] bool contains(const int64_t& value, const uint32_t& id) const {
if (int64map.count(value) == 0) {
return false;
}
auto ids = int64map.at(value);
return ids_t::contains(ids, id);
}
public:
~num_tree_t();
@ -19,11 +30,27 @@ public:
void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len);
void approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len);
void range_inclusive_contains(const int64_t& start, const int64_t& end,
const uint32_t& context_ids_length,
const uint32_t*& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const;
size_t get(int64_t value, std::vector<uint32_t>& geo_result_ids);
void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len);
void approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len);
void remove(uint64_t value, uint32_t id);
size_t size();
void contains(const NUM_COMPARATOR& comparator, const int64_t& value,
const uint32_t& context_ids_length,
const uint32_t*& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const;
};

View File

@ -91,7 +91,9 @@ public:
static void merge(const std::vector<void*>& posting_lists, std::vector<uint32_t>& result_ids);
static void intersect(const std::vector<void*>& posting_lists, std::vector<uint32_t>& result_ids);
static void intersect(const std::vector<void*>& posting_lists, std::vector<uint32_t>& result_ids,
const uint32_t& context_ids_length = 0,
const uint32_t* context_ids = nullptr);
static void get_array_token_positions(
uint32_t id,

View File

@ -14,14 +14,15 @@ struct KV {
uint64_t key{};
uint64_t distinct_key{};
int64_t scores[3]{}; // match score + 2 custom attributes
reference_filter_result_t* reference_filter_result;
reference_filter_result_t* reference_filter_result = nullptr;
// to be used only in final aggregation
uint64_t* query_indices = nullptr;
KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, uint8_t match_score_index, const int64_t *scores):
KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, uint8_t match_score_index, const int64_t *scores,
reference_filter_result_t* reference_filter_result = nullptr):
match_score_index(match_score_index), query_index(queryIndex), array_index(0), key(key),
distinct_key(distinct_key) {
distinct_key(distinct_key), reference_filter_result(reference_filter_result) {
this->scores[0] = scores[0];
this->scores[1] = scores[1];
this->scores[2] = scores[2];

View File

@ -2519,8 +2519,6 @@ Option<bool> Collection::get_filter_ids(const std::string& filter_query, filter_
}
Option<std::string> Collection::get_reference_field(const std::string & collection_name) const {
std::shared_lock lock(mutex);
std::string reference_field_name;
for (auto const& pair: reference_fields) {
auto reference_pair = pair.second;
@ -2541,13 +2539,13 @@ Option<std::string> Collection::get_reference_field(const std::string & collecti
Option<bool> Collection::get_reference_filter_ids(const std::string & filter_query,
filter_result_t& filter_result,
const std::string & collection_name) const {
std::shared_lock lock(mutex);
auto reference_field_op = get_reference_field(collection_name);
if (!reference_field_op.ok()) {
return Option<bool>(reference_field_op.code(), reference_field_op.error());
}
std::shared_lock lock(mutex);
const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_";
filter_node_t* filter_tree_root = nullptr;
Option<bool> parse_op = filter::parse_filter_query(filter_query, search_schema,
@ -2558,7 +2556,7 @@ Option<bool> Collection::get_reference_filter_ids(const std::string & filter_que
// Reference helper field has the sequence id of other collection's documents.
auto field_name = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX;
auto filter_op = index->do_reference_filtering_with_lock(filter_tree_root, filter_result, field_name);
auto filter_op = index->do_reference_filtering_with_lock(filter_tree_root, filter_result, name, field_name);
if (!filter_op.ok()) {
return filter_op;
}
@ -2583,22 +2581,6 @@ Option<bool> Collection::validate_reference_filter(const std::string& filter_que
return Option<bool>(true);
}
Option<bool> Collection::validate_reference_filter(const std::string& filter_query) const {
std::shared_lock lock(mutex);
const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_";
filter_node_t* filter_tree_root = nullptr;
Option<bool> filter_op = filter::parse_filter_query(filter_query, search_schema,
store, doc_id_prefix, filter_tree_root);
if(!filter_op.ok()) {
return filter_op;
}
delete filter_tree_root;
return Option<bool>(true);
}
bool Collection::facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count,
const nlohmann::json &document, std::string &value) const {

View File

@ -418,38 +418,6 @@ Option<bool> toParseTree(std::queue<std::string>& postfix, filter_node_t*& root,
} else {
filter filter_exp;
// Expected value: $Collection(...)
bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')');
if (is_referenced_filter) {
size_t parenthesis_index = expression.find('(');
std::string collection_name = expression.substr(1, parenthesis_index - 1);
auto& cm = CollectionManager::get_instance();
auto collection = cm.get_collection(collection_name);
if (collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + collection_name + "` not found.");
}
filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)};
filter_exp.referenced_collection_name = collection_name;
auto op = collection->validate_reference_filter(filter_exp.field_name);
if (!op.ok()) {
return Option<bool>(400, "Failed to parse reference filter on `" + collection_name +
"` collection: " + op.error());
}
} else {
Option<bool> toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix);
if (!toFilter_op.ok()) {
while(!nodeStack.empty()) {
auto filterNode = nodeStack.top();
delete filterNode;
nodeStack.pop();
}
return toFilter_op;
}
}
// Expected value: $Collection(...)
bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')');
if (is_referenced_filter) {

View File

@ -1451,11 +1451,18 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
const int64_t value,
uint32_t*& ids,
size_t& ids_len) const {
const uint32_t& context_ids_length,
const uint32_t* context_ids,
size_t& ids_len,
uint32_t*& ids) const {
uint32_t* to_exclude_ids = nullptr;
size_t to_exclude_ids_len = 0;
num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len);
if (context_ids_length != 0) {
num_tree->contains(EQUALS, value, context_ids_length, context_ids, to_exclude_ids_len, to_exclude_ids);
} else {
num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len);
}
auto all_ids = seq_ids->uncompress();
auto all_ids_size = seq_ids->num_ids();
@ -1470,17 +1477,25 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
delete[] to_exclude_ids;
uint32_t* out = nullptr;
ids_len = ArrayUtils::or_scalar(ids, ids_len,
to_include_ids, to_include_ids_len, &out);
ids_len = ArrayUtils::or_scalar(ids, ids_len, to_include_ids, to_include_ids_len, &out);
delete[] ids;
delete[] to_include_ids;
ids = out;
}
bool Index::field_is_indexed(const std::string& field_name) const {
return search_index.count(field_name) != 0 ||
numerical_index.count(field_name) != 0 ||
geopoint_index.count(field_name) != 0;
}
Option<bool> Index::do_filtering(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name) const {
const std::string& collection_name,
const uint32_t& context_ids_length,
const uint32_t* context_ids) const {
// auto begin = std::chrono::high_resolution_clock::now();
const filter a_filter = root->filter_exp;
@ -1492,13 +1507,46 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if (collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found.");
}
filter_result_t reference_filter_result;
auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name,
result,
reference_filter_result,
collection_name);
if (!reference_filter_op.ok()) {
return reference_filter_op;
}
if (context_ids_length != 0) {
std::vector<uint32_t> include_indexes;
include_indexes.reserve(std::min(context_ids_length, reference_filter_result.count));
size_t context_index = 0, reference_result_index = 0;
while (context_index < context_ids_length && reference_result_index < reference_filter_result.count) {
if (context_ids[context_index] == reference_filter_result.docs[reference_result_index]) {
include_indexes.push_back(reference_result_index);
context_index++;
reference_result_index++;
} else if (context_ids[context_index] < reference_filter_result.docs[reference_result_index]) {
context_index++;
} else {
reference_result_index++;
}
}
result.count = include_indexes.size();
result.docs = new uint32_t[include_indexes.size()];
auto& result_references = result.reference_filter_results[a_filter.referenced_collection_name];
result_references = new reference_filter_result_t[include_indexes.size()];
for (uint32_t i = 0; i < include_indexes.size(); i++) {
result.docs[i] = reference_filter_result.docs[include_indexes[i]];
result_references[i] = reference_filter_result.reference_filter_results[a_filter.referenced_collection_name][include_indexes[i]];
}
return Option(true);
}
result = reference_filter_result;
return Option(true);
}
@ -1511,18 +1559,26 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
std::sort(result_ids.begin(), result_ids.end());
result.docs = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), result.docs);
result.count = result_ids.size();
auto result_array = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), result_array);
if (context_ids_length != 0) {
uint32_t* out = nullptr;
result.count = ArrayUtils::and_scalar(context_ids, context_ids_length,
result_array, result_ids.size(), &out);
delete[] result_array;
result.docs = out;
return Option(true);
}
result.docs = result_array;
result.count = result_ids.size();
return Option(true);
}
bool has_search_index = search_index.count(a_filter.field_name) != 0 ||
numerical_index.count(a_filter.field_name) != 0 ||
geopoint_index.count(a_filter.field_name) != 0;
if (!has_search_index) {
if (!field_is_indexed(a_filter.field_name)) {
return Option(true);
}
@ -1540,13 +1596,25 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
int64_t range_end_value = (int64_t)std::stol(next_filter_value);
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
auto const range_end_value = (int64_t)std::stol(next_filter_value);
if (context_ids_length != 0) {
num_tree->range_inclusive_contains(value, range_end_value, context_ids_length, context_ids,
result_ids_len, result_ids);
} else {
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
}
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
numeric_not_equals_filter(num_tree, value, context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[fi], value,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
}
}
}
} else if (f.is_float()) {
@ -1560,12 +1628,25 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->range_inclusive_contains(float_int64, range_end_value, context_ids_length, context_ids,
result_ids_len, result_ids);
} else {
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
}
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
numeric_not_equals_filter(num_tree, float_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[fi], float_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
}
}
}
} else if (f.is_bool()) {
@ -1575,9 +1656,15 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
for (const std::string& filter_value : a_filter.values) {
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
if (a_filter.comparators[value_index] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, bool_int64, result_ids, result_ids_len);
numeric_not_equals_filter(num_tree, bool_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[value_index], bool_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
}
}
value_index++;
@ -1652,6 +1739,14 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
// `geo_result_ids` will contain all IDs that are within approximately within query radius
// we still need to do another round of exact filtering on them
if (context_ids_length != 0) {
uint32_t *out = nullptr;
uint32_t count = ArrayUtils::and_scalar(context_ids, context_ids_length,
&geo_result_ids[0], geo_result_ids.size(), &out);
geo_result_ids = std::vector<uint32_t>(out, out + count);
}
std::vector<uint32_t> exact_geo_result_ids;
if (f.is_single_geopoint()) {
@ -1739,7 +1834,7 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) {
// needs intersection + exact matching (unlike CONTAINS)
std::vector<uint32_t> result_id_vec;
posting_t::intersect(posting_lists, result_id_vec);
posting_t::intersect(posting_lists, result_id_vec, context_ids_length, context_ids);
if (result_id_vec.empty()) {
continue;
@ -1763,7 +1858,7 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
} else {
// CONTAINS
size_t before_size = f_id_buff.size();
posting_t::intersect(posting_lists, f_id_buff);
posting_t::intersect(posting_lists, f_id_buff, context_ids_length, context_ids);
if (f_id_buff.size() == before_size) {
continue;
}
@ -1811,6 +1906,17 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
result_ids = to_include_ids;
result_ids_len = to_include_ids_len;
if (context_ids_length != 0) {
uint32_t *out = nullptr;
result.count = ArrayUtils::and_scalar(context_ids, context_ids_length,
result_ids, result_ids_len, &out);
delete[] result_ids;
result.docs = out;
return Option(true);
}
}
result.docs = result_ids;
@ -1824,6 +1930,28 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/
}
void Index::aproximate_numerical_match(num_tree_t* const num_tree,
const NUM_COMPARATOR& comparator,
const int64_t& value,
const int64_t& range_end_value,
uint32_t& filter_ids_length) const {
if (comparator == RANGE_INCLUSIVE) {
num_tree->approx_range_inclusive_search_count(value, range_end_value, filter_ids_length);
return;
}
if (comparator == NOT_EQUALS) {
uint32_t to_exclude_ids_len = 0;
num_tree->approx_search_count(EQUALS, value, to_exclude_ids_len);
auto all_ids_size = seq_ids->num_ids();
filter_ids_length += (all_ids_size - to_exclude_ids_len);
return;
}
num_tree->approx_search_count(comparator, value, filter_ids_length);
}
Option<bool> Index::rearrange_filter_tree(filter_node_t* const root,
uint32_t& filter_ids_length,
const std::string& collection_name) const {
@ -1861,13 +1989,94 @@ Option<bool> Index::rearrange_filter_tree(filter_node_t* const root,
return Option(true);
}
filter_result_t result;
auto filter_op = do_filtering(root, result, collection_name);
if (!filter_op.ok()) {
return filter_op;
auto a_filter = root->filter_exp;
if (a_filter.field_name == "id") {
filter_ids_length = a_filter.values.size();
return Option(true);
}
if (!field_is_indexed(a_filter.field_name)) {
return Option(true);
}
field f = search_schema.at(a_filter.field_name);
if (f.is_integer()) {
auto num_tree = numerical_index.at(f.name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
auto const value = (int64_t)std::stol(filter_value);
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = (int64_t)std::stol(next_filter_value);
aproximate_numerical_match(num_tree, a_filter.comparators[fi], value, range_end_value,
filter_ids_length);
fi++;
} else {
aproximate_numerical_match(num_tree, a_filter.comparators[fi], value, 0, filter_ids_length);
}
}
} else if (f.is_float()) {
auto num_tree = numerical_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
float value = (float)std::atof(filter_value.c_str());
int64_t float_int64 = float_to_int64_t(value);
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
aproximate_numerical_match(num_tree, a_filter.comparators[fi], float_int64, range_end_value,
filter_ids_length);
fi++;
} else {
aproximate_numerical_match(num_tree, a_filter.comparators[fi], float_int64, 0, filter_ids_length);
}
}
} else if (f.is_bool()) {
auto num_tree = numerical_index.at(a_filter.field_name);
size_t value_index = 0;
for (const std::string& filter_value : a_filter.values) {
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
aproximate_numerical_match(num_tree, a_filter.comparators[value_index], bool_int64, 0, filter_ids_length);
value_index++;
}
} else if (f.is_geopoint()) {
filter_ids_length = 100;
} else if (f.is_string()) {
art_tree* t = search_index.at(a_filter.field_name);
for (const std::string& filter_value : a_filter.values) {
Tokenizer tokenizer(filter_value, true, false, f.locale, symbols_to_index, token_separators);
std::string str_token;
size_t token_index = 0;
while (tokenizer.next(str_token, token_index)) {
auto const leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(),
str_token.length()+1);
if (leaf == nullptr) {
continue;
}
filter_ids_length += posting_t::num_ids(leaf->values);
}
}
}
if (a_filter.apply_not_equals) {
auto all_ids_size = seq_ids->num_ids();
filter_ids_length = (all_ids_size - filter_ids_length);
}
filter_ids_length = result.count;
return Option(true);
}
@ -1884,19 +2093,23 @@ Option<bool> Index::rearranging_recursive_filter(filter_node_t* const filter_tre
}
void copy_reference_ids(filter_result_t& from, filter_result_t& to) {
if (to.count > 0 && from.reference_filter_result != nullptr && from.reference_filter_result->count > 0) {
to.reference_filter_result = new reference_filter_result_t[to.count];
if (to.count > 0 && !from.reference_filter_results.empty()) {
for (const auto &item: from.reference_filter_results) {
auto& from_reference_result = from.reference_filter_results[item.first];
auto& to_reference_result = to.reference_filter_results[item.first];
to_reference_result = new reference_filter_result_t[to.count];
size_t to_index = 0, from_index = 0;
while (to_index < to.count && from_index < from.count) {
if (to.docs[to_index] == from.docs[from_index]) {
to.reference_filter_result[to_index] = from.reference_filter_result[from_index];
to_index++;
from_index++;
} else if (to.docs[to_index] < from.docs[from_index]) {
to_index++;
} else {
from_index++;
size_t to_index = 0, from_index = 0;
while (to_index < to.count && from_index < from.count) {
if (to.docs[to_index] == from.docs[from_index]) {
to_reference_result[to_index] = from_reference_result[from_index];
to_index++;
from_index++;
} else if (to.docs[to_index] < from.docs[from_index]) {
to_index++;
} else {
from_index++;
}
}
}
}
@ -1938,8 +2151,8 @@ Option<bool> Index::recursive_filter(filter_node_t* const root,
}
result.docs = filtered_results;
if (l_result.reference_filter_result != nullptr || r_result.reference_filter_result != nullptr) {
copy_reference_ids(l_result.reference_filter_result != nullptr ? l_result : r_result, result);
if (!l_result.reference_filter_results.empty() || !r_result.reference_filter_results.empty()) {
copy_reference_ids(!l_result.reference_filter_results.empty() ? l_result : r_result, result);
}
return Option(true);
@ -1982,7 +2195,8 @@ Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root
Option<bool> Index::do_reference_filtering_with_lock(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string & reference_helper_field_name) const {
const std::string& collection_name,
const std::string& reference_helper_field_name) const {
std::shared_lock lock(mutex);
filter_result_t reference_filter_result;
@ -2002,15 +2216,17 @@ Option<bool> Index::do_reference_filtering_with_lock(filter_node_t* const filter
filter_result.count = reference_map.size();
filter_result.docs = new uint32_t[reference_map.size()];
filter_result.reference_filter_result = new reference_filter_result_t[reference_map.size()];
filter_result.reference_filter_results[collection_name] = new reference_filter_result_t[reference_map.size()];
size_t doc_index = 0;
for (auto &item: reference_map) {
filter_result.docs[doc_index] = item.first;
filter_result.reference_filter_result[doc_index].count = item.second.size();
filter_result.reference_filter_result[doc_index].docs = new uint32_t[item.second.size()];
std::copy(item.second.begin(), item.second.end(), filter_result.reference_filter_result[doc_index].docs);
auto& reference_result = filter_result.reference_filter_results[collection_name][doc_index];
reference_result.count = item.second.size();
reference_result.docs = new uint32_t[item.second.size()];
std::copy(item.second.begin(), item.second.end(), reference_result.docs);
doc_index++;
}
@ -2080,7 +2296,7 @@ void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
scores[1] = int64_t(1);
scores[2] = int64_t(1);
KV kv(searched_queries.size(), seq_id, distinct_id, 0, scores);
KV kv(searched_queries.size(), seq_id, distinct_id, 0, scores, nullptr);
curated_topster->add(&kv);
}
}
@ -2582,7 +2798,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
int64_t match_score_index = -1;
result_ids.push_back(seq_id);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
@ -2681,7 +2898,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
//LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first;
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {

View File

@ -43,6 +43,61 @@ void num_tree_t::range_inclusive_search(int64_t start, int64_t end, uint32_t** i
*ids = out;
}
void num_tree_t::approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len) {
if (int64map.empty()) {
return;
}
auto it_start = int64map.lower_bound(start); // iter values will be >= start
while (it_start != int64map.end() && it_start->first <= end) {
uint32_t val_ids = ids_t::num_ids(it_start->second);
ids_len += val_ids;
it_start++;
}
}
bool num_tree_t::range_inclusive_contains(const int64_t& start, const int64_t& end, const uint32_t& id) const {
if (int64map.empty()) {
return false;
}
auto it_start = int64map.lower_bound(start); // iter values will be >= start
while (it_start != int64map.end() && it_start->first <= end) {
if (ids_t::contains(it_start->second, id)) {
return true;
}
}
return false;
}
void num_tree_t::range_inclusive_contains(const int64_t& start, const int64_t& end,
const uint32_t& context_ids_length,
const uint32_t*& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const {
if (int64map.empty()) {
return;
}
std::vector<uint32_t> consolidated_ids;
consolidated_ids.reserve(context_ids_length);
for (uint32_t i = 0; i < context_ids_length; i++) {
if (range_inclusive_contains(start, end, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
}
}
uint32_t *out = nullptr;
result_ids_len = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
result_ids, result_ids_len, &out);
delete [] result_ids;
result_ids = out;
}
size_t num_tree_t::get(int64_t value, std::vector<uint32_t>& geo_result_ids) {
const auto& it = int64map.find(value);
if(it == int64map.end()) {
@ -132,6 +187,54 @@ void num_tree_t::search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids
}
}
void num_tree_t::approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len) {
if (int64map.empty()) {
return;
}
if (comparator == EQUALS) {
const auto& it = int64map.find(value);
if (it != int64map.end()) {
uint32_t val_ids = ids_t::num_ids(it->second);
ids_len += val_ids;
}
} else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
if (iter_ge_value == int64map.end()) {
return;
}
if (comparator == GREATER_THAN && iter_ge_value->first == value) {
iter_ge_value++;
}
while (iter_ge_value != int64map.end()) {
uint32_t val_ids = ids_t::num_ids(iter_ge_value->second);
ids_len += val_ids;
iter_ge_value++;
}
} else if (comparator == LESS_THAN || comparator == LESS_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
auto it = int64map.begin();
while (it != iter_ge_value) {
uint32_t val_ids = ids_t::num_ids(it->second);
ids_len += val_ids;
it++;
}
// for LESS_THAN_EQUALS, check if last iter entry is equal to value
if (it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) {
uint32_t val_ids = ids_t::num_ids(it->second);
ids_len += val_ids;
}
}
}
void num_tree_t::remove(uint64_t value, uint32_t id) {
if(int64map.count(value) != 0) {
void* arr = int64map[value];
@ -146,6 +249,75 @@ void num_tree_t::remove(uint64_t value, uint32_t id) {
}
}
void num_tree_t::contains(const NUM_COMPARATOR& comparator, const int64_t& value,
const uint32_t& context_ids_length,
const uint32_t*& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const {
if (int64map.empty()) {
return;
}
std::vector<uint32_t> consolidated_ids;
consolidated_ids.reserve(context_ids_length);
for (uint32_t i = 0; i < context_ids_length; i++) {
if (comparator == EQUALS) {
if (contains(value, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
}
} else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
if (iter_ge_value == int64map.end()) {
continue;
}
if (comparator == GREATER_THAN && iter_ge_value->first == value) {
iter_ge_value++;
}
while (iter_ge_value != int64map.end()) {
if (contains(iter_ge_value->first, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
break;
}
iter_ge_value++;
}
} else if(comparator == LESS_THAN || comparator == LESS_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
auto it = int64map.begin();
while (it != iter_ge_value) {
if (contains(it->first, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
break;
}
it++;
}
// for LESS_THAN_EQUALS, check if last iter entry is equal to value
if (it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) {
if (contains(it->first, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
break;
}
}
}
}
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
uint32_t *out = nullptr;
result_ids_len = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
result_ids, result_ids_len, &out);
delete[] result_ids;
result_ids = out;
}
size_t num_tree_t::size() {
return int64map.size();
}

View File

@ -386,7 +386,32 @@ void posting_t::merge(const std::vector<void*>& raw_posting_lists, std::vector<u
}
}
void posting_t::intersect(const std::vector<void*>& raw_posting_lists, std::vector<uint32_t>& result_ids) {
void posting_t::intersect(const std::vector<void*>& raw_posting_lists, std::vector<uint32_t>& result_ids,
const uint32_t& context_ids_length,
const uint32_t* context_ids) {
if (context_ids_length != 0) {
if (raw_posting_lists.empty()) {
return;
}
for (uint32_t i = 0; i < context_ids_length; i++) {
bool is_present = true;
for (auto const& raw_posting_list: raw_posting_lists) {
if (!contains(raw_posting_list, context_ids[i])) {
is_present = false;
break;
}
}
if (is_present) {
result_ids.push_back(context_ids[i]);
}
}
return;
}
// we will have to convert the compact posting list (if any) to full form
std::vector<posting_list_t*> plists;
std::vector<posting_list_t*> expanded_plists;

View File

@ -651,11 +651,11 @@ TEST_F(CollectionJoinTest, IncludeFieldsByReference_SingleMatch) {
ASSERT_FALSE(search_op.ok());
ASSERT_EQ("Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`.", search_op.error());
req_params["include_fields"] = "$foo(bar)";
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_FALSE(search_op.ok());
ASSERT_EQ("Referenced collection `foo` not found.", search_op.error());
// req_params["include_fields"] = "$foo(bar)";
// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
// ASSERT_FALSE(search_op.ok());
// ASSERT_EQ("Referenced collection `foo` not found.", search_op.error());
//
// req_params["include_fields"] = "$Customers(bar)";
// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
// ASSERT_TRUE(search_op.ok());