Merge pull request #921 from happy-san/v0.25

Joins -- part 2
This commit is contained in:
Kishore Nallan 2023-03-11 16:06:03 +05:30 committed by GitHub
commit 26a783e52b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1291 additions and 203 deletions

View File

@ -6,4 +6,4 @@ build --cxxopt="-std=c++17"
test --jobs=6
build --enable_platform_specific_config
build:linux --action_env=BAZEL_LINKLIBS="-l%:libstdc++.a -l%:libgcc.a"
build:linux --action_env=BAZEL_LINKLIBS="-l%:libstdc++.a -l%:libgcc.a"

1
.gitignore vendored
View File

@ -15,4 +15,3 @@ typesense-server-data/
.clwb/.bazelproject
.vscode/settings.json
/onnxruntime-prefix

View File

@ -255,4 +255,4 @@ target_sources(search PRIVATE ${ONNX_EXT_SRC_FILES})
add_dependencies(typesense-server onnxruntime_ext)
add_dependencies(typesense-test onnxruntime_ext)
add_dependencies(benchmark onnxruntime_ext)
add_dependencies(search onnxruntime_ext)
add_dependencies(search onnxruntime_ext)

View File

@ -268,6 +268,8 @@ private:
Option<std::string> get_reference_field(const std::string & collection_name) const;
public:
enum {MAX_ARRAY_MATCHES = 5};
@ -455,11 +457,13 @@ public:
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;
Option<std::string> get_reference_field(const std::string & collection_name) const;
/// Get approximate count of docs matching a reference filter on foo collection when $foo(...) filter is encountered.
Option<bool> get_approximate_reference_filter_ids(const std::string& filter_query,
uint32_t& filter_ids_length) const;
Option<bool> get_reference_filter_ids(const std::string & filter_query,
Option<bool> get_reference_filter_ids(const std::string& filter_query,
filter_result_t& filter_result,
const std::string & collection_name) const;
const std::string& collection_name) const;
Option<bool> validate_reference_filter(const std::string& filter_query) const;

View File

@ -513,7 +513,7 @@ struct filter {
bool apply_not_equals = false;
// Would store `Foo` in case of a filter expression like `$Foo(bar := baz)`
std::string referenced_collection_name;
std::string referenced_collection_name = "";
static const std::string RANGE_OPERATOR() {
return "..";
@ -594,12 +594,6 @@ struct filter {
filter_node_t*& root);
};
struct filter_tree_metrics {
int filter_exp_count;
int and_operator_count;
int or_operator_count;
};
struct filter_node_t {
filter filter_exp;
FILTER_OPERATOR filter_operator;
@ -622,7 +616,6 @@ struct filter_node_t {
right(right) {}
~filter_node_t() {
delete metrics;
delete left;
delete right;
}
@ -640,12 +633,37 @@ struct reference_filter_result_t {
struct filter_result_t {
uint32_t count = 0;
uint32_t* docs = nullptr;
reference_filter_result_t* reference_filter_result = nullptr;
// Collection name -> Reference filter result
std::map<std::string, reference_filter_result_t*> reference_filter_results;
filter_result_t() {}
filter_result_t(uint32_t count, uint32_t* docs) : count(count), docs(docs) {}
filter_result_t& operator=(filter_result_t&& obj) noexcept {
if (&obj == this)
return *this;
count = obj.count;
docs = obj.docs;
reference_filter_results = std::map(obj.reference_filter_results);
obj.docs = nullptr;
obj.reference_filter_results.clear();
return *this;
}
~filter_result_t() {
delete[] docs;
delete[] reference_filter_result;
for (const auto &item: reference_filter_results) {
delete[] item.second;
}
}
static void and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result);
static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result);
};
namespace sort_field_const {

View File

@ -467,28 +467,37 @@ private:
void numeric_not_equals_filter(num_tree_t* const num_tree,
const int64_t value,
uint32_t*& ids,
size_t& ids_len) const;
const uint32_t& context_ids_length,
uint32_t* const& context_ids,
size_t& ids_len,
uint32_t*& ids) const;
bool field_is_indexed(const std::string& field_name) const;
Option<bool> do_filtering(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name = "") const;
const std::string& collection_name = "",
const uint32_t& context_ids_length = 0,
uint32_t* const& context_ids = nullptr) const;
Option<bool> rearranging_recursive_filter (filter_node_t* const filter_tree_root,
filter_result_t& result,
const std::string& collection_name = "") const;
void aproximate_numerical_match(num_tree_t* const num_tree,
const NUM_COMPARATOR& comparator,
const int64_t& value,
const int64_t& range_end_value,
uint32_t& filter_ids_length) const;
Option<bool> recursive_filter(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name = "") const;
Option<bool> adaptive_filter(filter_node_t* const filter_tree_root,
filter_result_t& result,
const std::string& collection_name = "") const;
Option<bool> rearrange_filter_tree(filter_node_t* const root,
uint32_t& filter_ids_length,
const std::string& collection_name = "") const;
/// Traverses through filter tree to get the filter_result.
///
/// \param filter_tree_root
/// \param filter_result
/// \param collection_name Name of the collection to which current index belongs. Used to find the reference field in other collection.
/// \param context_ids_length Number of docs matching the search query.
/// \param context_ids Array of doc ids matching the search query.
Option<bool> recursive_filter(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string& collection_name = "",
const uint32_t& context_ids_length = 0,
uint32_t* const& context_ids = nullptr) const;
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;
@ -685,9 +694,28 @@ public:
filter_result_t& filter_result,
const std::string& collection_name = "") const;
/// Traverses through filter tree and gets an approximate doc count for each filter. Also arranges the children of
/// each operator in ascending order based on their approx doc count.
///
/// \param filter_tree_root
/// \param approx_filter_ids_length Approximate count of docs that would match the whole filter_by clause.
/// \param collection_name Name of the collection to which current index belongs. Used to find the reference field in other collection.
Option<bool> rearrange_filter_tree(filter_node_t* const filter_tree_root,
uint32_t& approx_filter_ids_length,
const std::string& collection_name = "") const;
Option<bool> _approximate_filter_ids(const filter& a_filter,
uint32_t& filter_ids_length,
const std::string& collection_name = "") const;
Option<bool> do_reference_filtering_with_lock(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string & reference_helper_field_name) const;
const std::string& collection_name,
const std::string& reference_helper_field_name) const;
/// Get approximate count of docs matching a reference filter on foo collection when $foo(...) filter is encountered.
Option<bool> get_approximate_reference_filter_ids_with_lock(filter_node_t* const filter_tree_root,
uint32_t& filter_ids_length) const;
void refresh_schemas(const std::vector<field>& new_fields, const std::vector<field>& del_fields);

View File

@ -11,6 +11,17 @@ class num_tree_t {
private:
std::map<int64_t, void*> int64map;
[[nodiscard]] bool range_inclusive_contains(const int64_t& start, const int64_t& end, const uint32_t& id) const;
[[nodiscard]] bool contains(const int64_t& value, const uint32_t& id) const {
if (int64map.count(value) == 0) {
return false;
}
auto ids = int64map.at(value);
return ids_t::contains(ids, id);
}
public:
~num_tree_t();
@ -19,11 +30,27 @@ public:
void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len);
void approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len);
void range_inclusive_contains(const int64_t& start, const int64_t& end,
const uint32_t& context_ids_length,
uint32_t* const& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const;
size_t get(int64_t value, std::vector<uint32_t>& geo_result_ids);
void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len);
void approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len);
void remove(uint64_t value, uint32_t id);
size_t size();
void contains(const NUM_COMPARATOR& comparator, const int64_t& value,
const uint32_t& context_ids_length,
uint32_t* const& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const;
};

View File

@ -91,7 +91,9 @@ public:
static void merge(const std::vector<void*>& posting_lists, std::vector<uint32_t>& result_ids);
static void intersect(const std::vector<void*>& posting_lists, std::vector<uint32_t>& result_ids);
static void intersect(const std::vector<void*>& posting_lists, std::vector<uint32_t>& result_ids,
const uint32_t& context_ids_length = 0,
const uint32_t* context_ids = nullptr);
static void get_array_token_positions(
uint32_t id,

View File

@ -14,14 +14,15 @@ struct KV {
uint64_t key{};
uint64_t distinct_key{};
int64_t scores[3]{}; // match score + 2 custom attributes
reference_filter_result_t* reference_filter_result;
reference_filter_result_t* reference_filter_result = nullptr;
// to be used only in final aggregation
uint64_t* query_indices = nullptr;
KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, uint8_t match_score_index, const int64_t *scores):
KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, uint8_t match_score_index, const int64_t *scores,
reference_filter_result_t* reference_filter_result = nullptr):
match_score_index(match_score_index), query_index(queryIndex), array_index(0), key(key),
distinct_key(distinct_key) {
distinct_key(distinct_key), reference_filter_result(reference_filter_result) {
this->scores[0] = scores[0];
this->scores[1] = scores[1];
this->scores[2] = scores[2];

View File

@ -246,6 +246,7 @@ nlohmann::json Collection::get_summary_json() const {
field_json[fields::reference] = coll_field.reference;
}
fields_arr.push_back(field_json);
}
@ -1515,7 +1516,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
// for grouping we have to re-aggregate
Topster& topster = *search_params->topster;
Topster& curated_topster = *search_params->curated_topster;
@ -2509,19 +2509,16 @@ Option<bool> Collection::get_filter_ids(const std::string& filter_query, filter_
filter_node_t* filter_tree_root = nullptr;
Option<bool> filter_op = filter::parse_filter_query(filter_query, search_schema,
store, doc_id_prefix, filter_tree_root);
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);
if(!filter_op.ok()) {
return filter_op;
}
index->do_filtering_with_lock(filter_tree_root, filter_result, name);
delete filter_tree_root;
return Option<bool>(true);
return index->do_filtering_with_lock(filter_tree_root, filter_result, name);
}
Option<std::string> Collection::get_reference_field(const std::string & collection_name) const {
std::shared_lock lock(mutex);
std::string reference_field_name;
for (auto const& pair: reference_fields) {
auto reference_pair = pair.second;
@ -2539,33 +2536,46 @@ Option<std::string> Collection::get_reference_field(const std::string & collecti
return Option(reference_field_name);
}
Option<bool> Collection::get_reference_filter_ids(const std::string & filter_query,
filter_result_t& filter_result,
const std::string & collection_name) const {
auto reference_field_op = get_reference_field(collection_name);
if (!reference_field_op.ok()) {
return Option<bool>(reference_field_op.code(), reference_field_op.error());
}
Option<bool> Collection::get_approximate_reference_filter_ids(const std::string& filter_query,
uint32_t& filter_ids_length) const {
std::shared_lock lock(mutex);
const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_";
filter_node_t* filter_tree_root = nullptr;
Option<bool> parse_op = filter::parse_filter_query(filter_query, search_schema,
store, doc_id_prefix, filter_tree_root);
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);
if(!parse_op.ok()) {
return parse_op;
}
return index->get_approximate_reference_filter_ids_with_lock(filter_tree_root, filter_ids_length);
}
Option<bool> Collection::get_reference_filter_ids(const std::string & filter_query,
filter_result_t& filter_result,
const std::string & collection_name) const {
std::shared_lock lock(mutex);
auto reference_field_op = get_reference_field(collection_name);
if (!reference_field_op.ok()) {
return Option<bool>(reference_field_op.code(), reference_field_op.error());
}
const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_";
filter_node_t* filter_tree_root = nullptr;
Option<bool> parse_op = filter::parse_filter_query(filter_query, search_schema,
store, doc_id_prefix, filter_tree_root);
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);
if(!parse_op.ok()) {
return parse_op;
}
// Reference helper field has the sequence id of other collection's documents.
auto field_name = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX;
auto filter_op = index->do_reference_filtering_with_lock(filter_tree_root, filter_result, field_name);
if (!filter_op.ok()) {
return filter_op;
}
delete filter_tree_root;
return Option<bool>(true);
return index->do_reference_filtering_with_lock(filter_tree_root, filter_result, name, field_name);
}
Option<bool> Collection::validate_reference_filter(const std::string& filter_query) const {
@ -2575,7 +2585,6 @@ Option<bool> Collection::validate_reference_filter(const std::string& filter_que
filter_node_t* filter_tree_root = nullptr;
Option<bool> filter_op = filter::parse_filter_query(filter_query, search_schema,
store, doc_id_prefix, filter_tree_root);
if(!filter_op.ok()) {
return filter_op;
}

View File

@ -384,37 +384,36 @@ Option<bool> toFilter(const std::string expression,
Option<bool> toParseTree(std::queue<std::string>& postfix, filter_node_t*& root,
const tsl::htrie_map<char, field>& search_schema,
const Store* store,
const std::string& doc_id_prefix,
int& and_operator_count,
int& or_operator_count) {
const std::string& doc_id_prefix) {
std::stack<filter_node_t*> nodeStack;
bool is_successful = true;
std::string error_message;
while (!postfix.empty()) {
const std::string expression = postfix.front();
postfix.pop();
filter_node_t* filter_node = nullptr;
filter_node_t *filter_node = nullptr;
if (isOperator(expression)) {
auto message = "Could not parse the filter query: unbalanced `" + expression + "` operands.";
if (nodeStack.empty()) {
return Option<bool>(400, message);
is_successful = false;
error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands.";
break;
}
auto operandB = nodeStack.top();
nodeStack.pop();
if (nodeStack.empty()) {
delete operandB;
return Option<bool>(400, message);
is_successful = false;
error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands.";
break;
}
auto operandA = nodeStack.top();
nodeStack.pop();
expression == "&&" ? and_operator_count++ : or_operator_count++;
filter_node = new filter_node_t(expression == "&&" ? AND : OR, operandA, operandB);
} else {
filter filter_exp;
<<<<<<< HEAD
// Expected value: $Collection(...)
bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')');
@ -422,10 +421,12 @@ Option<bool> toParseTree(std::queue<std::string>& postfix, filter_node_t*& root,
size_t parenthesis_index = expression.find('(');
std::string collection_name = expression.substr(1, parenthesis_index - 1);
auto& cm = CollectionManager::get_instance();
auto &cm = CollectionManager::get_instance();
auto collection = cm.get_collection(collection_name);
if (collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + collection_name + "` not found.");
is_successful = false;
error_message = "Referenced collection `" + collection_name + "` not found.";
break;
}
filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)};
@ -433,18 +434,17 @@ Option<bool> toParseTree(std::queue<std::string>& postfix, filter_node_t*& root,
auto op = collection->validate_reference_filter(filter_exp.field_name);
if (!op.ok()) {
return Option<bool>(400, "Failed to parse reference filter on `" + collection_name +
"` collection: " + op.error());
is_successful = false;
error_message = "Failed to parse reference filter on `" + collection_name + "` collection: " +
op.error();
break;
}
} else {
Option<bool> toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix);
if (!toFilter_op.ok()) {
while(!nodeStack.empty()) {
auto filterNode = nodeStack.top();
delete filterNode;
nodeStack.pop();
}
return toFilter_op;
is_successful = false;
error_message = toFilter_op.error();
break;
}
}
@ -454,11 +454,21 @@ Option<bool> toParseTree(std::queue<std::string>& postfix, filter_node_t*& root,
nodeStack.push(filter_node);
}
if (!is_successful) {
while (!nodeStack.empty()) {
auto filterNode = nodeStack.top();
delete filterNode;
nodeStack.pop();
}
return Option<bool>(400, error_message);
}
if (nodeStack.empty()) {
return Option<bool>(400, "Filter query cannot be empty.");
}
root = nodeStack.top();
return Option<bool>(true);
}
@ -489,22 +499,15 @@ Option<bool> filter::parse_filter_query(const std::string& filter_query,
return toPostfix_op;
}
int postfix_size = (int) postfix.size(), and_operator_count = 0, or_operator_count = 0;
Option<bool> toParseTree_op = toParseTree(postfix,
root,
search_schema,
store,
doc_id_prefix,
and_operator_count,
or_operator_count);
doc_id_prefix);
if (!toParseTree_op.ok()) {
return toParseTree_op;
}
root->metrics = new filter_tree_metrics{static_cast<int>(postfix_size - (and_operator_count + or_operator_count)),
and_operator_count,
or_operator_count};
return Option<bool>(true);
}
@ -980,3 +983,68 @@ void field::compact_nested_fields(tsl::htrie_map<char, field>& nested_fields) {
nested_fields.erase_prefix(field_name + ".");
}
}
void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) {
auto lenA = a.count, lenB = b.count;
if (lenA == 0 || lenB == 0) {
return;
}
result.docs = new uint32_t[std::min(lenA, lenB)];
auto A = a.docs, B = b.docs, out = result.docs;
const uint32_t *endA = A + lenA;
const uint32_t *endB = B + lenB;
for (auto const& item: a.reference_filter_results) {
if (result.reference_filter_results.count(item.first) == 0) {
result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)];
}
}
for (auto const& item: b.reference_filter_results) {
if (result.reference_filter_results.count(item.first) == 0) {
result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)];
}
}
while (true) {
while (*A < *B) {
SKIP_FIRST_COMPARE:
if (++A == endA) {
result.count = out - result.docs;
return;
}
}
while (*A > *B) {
if (++B == endB) {
result.count = out - result.docs;
return;
}
}
if (*A == *B) {
*out = *A;
for (auto const& item: a.reference_filter_results) {
auto& reference = result.reference_filter_results[item.first][out - result.docs];
reference.count = item.second[A - a.docs].count;
reference.docs = new uint32_t[reference.count];
memcpy(reference.docs, item.second[A - a.docs].docs, reference.count * sizeof(uint32_t));
}
for (auto const& item: b.reference_filter_results) {
auto& reference = result.reference_filter_results[item.first][out - result.docs];
reference.count = item.second[B - b.docs].count;
reference.docs = new uint32_t[reference.count];
memcpy(reference.docs, item.second[B - b.docs].docs, reference.count * sizeof(uint32_t));
}
out++;
if (++A == endA || ++B == endB) {
result.count = out - result.docs;
return;
}
} else {
goto SKIP_FIRST_COMPARE;
}
}
}

View File

@ -1451,11 +1451,18 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
const int64_t value,
uint32_t*& ids,
size_t& ids_len) const {
const uint32_t& context_ids_length,
uint32_t* const& context_ids,
size_t& ids_len,
uint32_t*& ids) const {
uint32_t* to_exclude_ids = nullptr;
size_t to_exclude_ids_len = 0;
num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len);
if (context_ids_length != 0) {
num_tree->contains(EQUALS, value, context_ids_length, context_ids, to_exclude_ids_len, to_exclude_ids);
} else {
num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len);
}
auto all_ids = seq_ids->uncompress();
auto all_ids_size = seq_ids->num_ids();
@ -1470,17 +1477,25 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
delete[] to_exclude_ids;
uint32_t* out = nullptr;
ids_len = ArrayUtils::or_scalar(ids, ids_len,
to_include_ids, to_include_ids_len, &out);
ids_len = ArrayUtils::or_scalar(ids, ids_len, to_include_ids, to_include_ids_len, &out);
delete[] ids;
delete[] to_include_ids;
ids = out;
}
bool Index::field_is_indexed(const std::string& field_name) const {
return search_index.count(field_name) != 0 ||
numerical_index.count(field_name) != 0 ||
geopoint_index.count(field_name) != 0;
}
Option<bool> Index::do_filtering(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name) const {
const std::string& collection_name,
const uint32_t& context_ids_length,
uint32_t* const& context_ids) const {
// auto begin = std::chrono::high_resolution_clock::now();
const filter a_filter = root->filter_exp;
@ -1492,13 +1507,46 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if (collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found.");
}
filter_result_t reference_filter_result;
auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name,
result,
reference_filter_result,
collection_name);
if (!reference_filter_op.ok()) {
return reference_filter_op;
}
if (context_ids_length != 0) {
std::vector<uint32_t> include_indexes;
include_indexes.reserve(std::min(context_ids_length, reference_filter_result.count));
size_t context_index = 0, reference_result_index = 0;
while (context_index < context_ids_length && reference_result_index < reference_filter_result.count) {
if (context_ids[context_index] == reference_filter_result.docs[reference_result_index]) {
include_indexes.push_back(reference_result_index);
context_index++;
reference_result_index++;
} else if (context_ids[context_index] < reference_filter_result.docs[reference_result_index]) {
context_index++;
} else {
reference_result_index++;
}
}
result.count = include_indexes.size();
result.docs = new uint32_t[include_indexes.size()];
auto& result_references = result.reference_filter_results[a_filter.referenced_collection_name];
result_references = new reference_filter_result_t[include_indexes.size()];
for (uint32_t i = 0; i < include_indexes.size(); i++) {
result.docs[i] = reference_filter_result.docs[include_indexes[i]];
result_references[i] = reference_filter_result.reference_filter_results[a_filter.referenced_collection_name][include_indexes[i]];
}
return Option(true);
}
result = std::move(reference_filter_result);
return Option(true);
}
@ -1511,18 +1559,26 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
std::sort(result_ids.begin(), result_ids.end());
result.docs = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), result.docs);
result.count = result_ids.size();
auto result_array = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), result_array);
if (context_ids_length != 0) {
uint32_t* out = nullptr;
result.count = ArrayUtils::and_scalar(context_ids, context_ids_length,
result_array, result_ids.size(), &out);
delete[] result_array;
result.docs = out;
return Option(true);
}
result.docs = result_array;
result.count = result_ids.size();
return Option(true);
}
bool has_search_index = search_index.count(a_filter.field_name) != 0 ||
numerical_index.count(a_filter.field_name) != 0 ||
geopoint_index.count(a_filter.field_name) != 0;
if (!has_search_index) {
if (!field_is_indexed(a_filter.field_name)) {
return Option(true);
}
@ -1540,13 +1596,25 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
int64_t range_end_value = (int64_t)std::stol(next_filter_value);
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
auto const range_end_value = (int64_t)std::stol(next_filter_value);
if (context_ids_length != 0) {
num_tree->range_inclusive_contains(value, range_end_value, context_ids_length, context_ids,
result_ids_len, result_ids);
} else {
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
}
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
numeric_not_equals_filter(num_tree, value, context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[fi], value,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
}
}
}
} else if (f.is_float()) {
@ -1560,12 +1628,25 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->range_inclusive_contains(float_int64, range_end_value, context_ids_length, context_ids,
result_ids_len, result_ids);
} else {
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
}
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
numeric_not_equals_filter(num_tree, float_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[fi], float_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
}
}
}
} else if (f.is_bool()) {
@ -1575,9 +1656,15 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
for (const std::string& filter_value : a_filter.values) {
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
if (a_filter.comparators[value_index] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, bool_int64, result_ids, result_ids_len);
numeric_not_equals_filter(num_tree, bool_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[value_index], bool_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
}
}
value_index++;
@ -1652,6 +1739,14 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
// `geo_result_ids` will contain all IDs that are within approximately within query radius
// we still need to do another round of exact filtering on them
if (context_ids_length != 0) {
uint32_t *out = nullptr;
uint32_t count = ArrayUtils::and_scalar(context_ids, context_ids_length,
&geo_result_ids[0], geo_result_ids.size(), &out);
geo_result_ids = std::vector<uint32_t>(out, out + count);
}
std::vector<uint32_t> exact_geo_result_ids;
if (f.is_single_geopoint()) {
@ -1739,7 +1834,7 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) {
// needs intersection + exact matching (unlike CONTAINS)
std::vector<uint32_t> result_id_vec;
posting_t::intersect(posting_lists, result_id_vec);
posting_t::intersect(posting_lists, result_id_vec, context_ids_length, context_ids);
if (result_id_vec.empty()) {
continue;
@ -1763,7 +1858,7 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
} else {
// CONTAINS
size_t before_size = f_id_buff.size();
posting_t::intersect(posting_lists, f_id_buff);
posting_t::intersect(posting_lists, f_id_buff, context_ids_length, context_ids);
if (f_id_buff.size() == before_size) {
continue;
}
@ -1811,6 +1906,17 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
result_ids = to_include_ids;
result_ids_len = to_include_ids_len;
if (context_ids_length != 0) {
uint32_t *out = nullptr;
result.count = ArrayUtils::and_scalar(context_ids, context_ids_length,
result_ids, result_ids_len, &out);
delete[] result_ids;
result.docs = out;
return Option(true);
}
}
result.docs = result_ids;
@ -1824,8 +1930,132 @@ Option<bool> Index::do_filtering(filter_node_t* const root,
LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/
}
void Index::aproximate_numerical_match(num_tree_t* const num_tree,
const NUM_COMPARATOR& comparator,
const int64_t& value,
const int64_t& range_end_value,
uint32_t& filter_ids_length) const {
if (comparator == RANGE_INCLUSIVE) {
num_tree->approx_range_inclusive_search_count(value, range_end_value, filter_ids_length);
return;
}
if (comparator == NOT_EQUALS) {
uint32_t to_exclude_ids_len = 0;
num_tree->approx_search_count(EQUALS, value, to_exclude_ids_len);
auto all_ids_size = seq_ids->num_ids();
filter_ids_length += (all_ids_size - to_exclude_ids_len);
return;
}
num_tree->approx_search_count(comparator, value, filter_ids_length);
}
Option<bool> Index::_approximate_filter_ids(const filter& a_filter,
uint32_t& filter_ids_length,
const std::string& collection_name) const {
if (!a_filter.referenced_collection_name.empty()) {
auto& cm = CollectionManager::get_instance();
auto collection = cm.get_collection(a_filter.referenced_collection_name);
if (collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found.");
}
return collection->get_approximate_reference_filter_ids(a_filter.field_name, filter_ids_length);
}
if (a_filter.field_name == "id") {
filter_ids_length = a_filter.values.size();
return Option(true);
}
if (!field_is_indexed(a_filter.field_name)) {
return Option(true);
}
field f = search_schema.at(a_filter.field_name);
if (f.is_integer()) {
auto num_tree = numerical_index.at(f.name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
auto const value = (int64_t)std::stol(filter_value);
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = (int64_t)std::stol(next_filter_value);
aproximate_numerical_match(num_tree, a_filter.comparators[fi], value, range_end_value,
filter_ids_length);
fi++;
} else {
aproximate_numerical_match(num_tree, a_filter.comparators[fi], value, 0, filter_ids_length);
}
}
} else if (f.is_float()) {
auto num_tree = numerical_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
float value = (float)std::atof(filter_value.c_str());
int64_t float_int64 = float_to_int64_t(value);
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
aproximate_numerical_match(num_tree, a_filter.comparators[fi], float_int64, range_end_value,
filter_ids_length);
fi++;
} else {
aproximate_numerical_match(num_tree, a_filter.comparators[fi], float_int64, 0, filter_ids_length);
}
}
} else if (f.is_bool()) {
auto num_tree = numerical_index.at(a_filter.field_name);
size_t value_index = 0;
for (const std::string& filter_value : a_filter.values) {
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
aproximate_numerical_match(num_tree, a_filter.comparators[value_index], bool_int64, 0, filter_ids_length);
value_index++;
}
} else if (f.is_geopoint()) {
filter_ids_length = 100;
} else if (f.is_string()) {
art_tree* t = search_index.at(a_filter.field_name);
for (const std::string& filter_value : a_filter.values) {
Tokenizer tokenizer(filter_value, true, false, f.locale, symbols_to_index, token_separators);
std::string str_token;
size_t token_index = 0;
while (tokenizer.next(str_token, token_index)) {
auto const leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(),
str_token.length()+1);
if (leaf == nullptr) {
continue;
}
filter_ids_length += posting_t::num_ids(leaf->values);
}
}
}
if (a_filter.apply_not_equals) {
auto all_ids_size = seq_ids->num_ids();
filter_ids_length = (all_ids_size - filter_ids_length);
}
return Option(true);
}
Option<bool> Index::rearrange_filter_tree(filter_node_t* const root,
uint32_t& filter_ids_length,
uint32_t& approx_filter_ids_length,
const std::string& collection_name) const {
if (root == nullptr) {
return Option(true);
@ -1849,9 +2079,9 @@ Option<bool> Index::rearrange_filter_tree(filter_node_t* const root,
}
if (root->filter_operator == AND) {
filter_ids_length = std::min(l_filter_ids_length, r_filter_ids_length);
approx_filter_ids_length = std::min(l_filter_ids_length, r_filter_ids_length);
} else {
filter_ids_length = l_filter_ids_length + r_filter_ids_length;
approx_filter_ids_length = l_filter_ids_length + r_filter_ids_length;
}
if (l_filter_ids_length > r_filter_ids_length) {
@ -1861,42 +2091,28 @@ Option<bool> Index::rearrange_filter_tree(filter_node_t* const root,
return Option(true);
}
filter_result_t result;
auto filter_op = do_filtering(root, result, collection_name);
if (!filter_op.ok()) {
return filter_op;
}
filter_ids_length = result.count;
_approximate_filter_ids(root->filter_exp, approx_filter_ids_length, collection_name);
return Option(true);
}
Option<bool> Index::rearranging_recursive_filter(filter_node_t* const filter_tree_root,
filter_result_t& result,
const std::string& collection_name) const {
uint32_t filter_ids_length = 0;
auto rearrange_op = rearrange_filter_tree(filter_tree_root, filter_ids_length, collection_name);
if (!rearrange_op.ok()) {
return rearrange_op;
}
return recursive_filter(filter_tree_root, result, collection_name);
}
void copy_reference_ids(filter_result_t& from, filter_result_t& to) {
if (to.count > 0 && from.reference_filter_result != nullptr && from.reference_filter_result->count > 0) {
to.reference_filter_result = new reference_filter_result_t[to.count];
if (to.count > 0 && !from.reference_filter_results.empty()) {
for (const auto &item: from.reference_filter_results) {
auto& from_reference_result = from.reference_filter_results[item.first];
auto& to_reference_result = to.reference_filter_results[item.first];
to_reference_result = new reference_filter_result_t[to.count];
size_t to_index = 0, from_index = 0;
while (to_index < to.count && from_index < from.count) {
if (to.docs[to_index] == from.docs[from_index]) {
to.reference_filter_result[to_index] = from.reference_filter_result[from_index];
to_index++;
from_index++;
} else if (to.docs[to_index] < from.docs[from_index]) {
to_index++;
} else {
from_index++;
size_t to_index = 0, from_index = 0;
while (to_index < to.count && from_index < from.count) {
if (to.docs[to_index] == from.docs[from_index]) {
to_reference_result[to_index] = from_reference_result[from_index];
to_index++;
from_index++;
} else if (to.docs[to_index] < from.docs[from_index]) {
to_index++;
} else {
from_index++;
}
}
}
}
@ -1904,7 +2120,9 @@ void copy_reference_ids(filter_result_t& from, filter_result_t& to) {
Option<bool> Index::recursive_filter(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name) const {
const std::string& collection_name,
const uint32_t& context_ids_length,
uint32_t* const& context_ids) const {
if (root == nullptr) {
return Option(true);
}
@ -1912,7 +2130,7 @@ Option<bool> Index::recursive_filter(filter_node_t* const root,
if (root->isOperator) {
filter_result_t l_result;
if (root->left != nullptr) {
auto filter_op = recursive_filter(root->left, l_result , collection_name);
auto filter_op = recursive_filter(root->left, l_result , collection_name, context_ids_length, context_ids);
if (!filter_op.ok()) {
return filter_op;
}
@ -1920,51 +2138,30 @@ Option<bool> Index::recursive_filter(filter_node_t* const root,
filter_result_t r_result;
if (root->right != nullptr) {
auto filter_op = recursive_filter(root->right, r_result , collection_name);
auto filter_op = recursive_filter(root->right, r_result , collection_name, context_ids_length, context_ids);
if (!filter_op.ok()) {
return filter_op;
}
}
uint32_t* filtered_results = nullptr;
if (root->filter_operator == AND) {
result.count = ArrayUtils::and_scalar(
l_result.docs, l_result.count, r_result.docs,
r_result.count, &filtered_results);
filter_result_t::and_filter_results(l_result, r_result, result);
} else {
uint32_t* filtered_results = nullptr;
result.count = ArrayUtils::or_scalar(
l_result.docs, l_result.count, r_result.docs,
r_result.count, &filtered_results);
}
result.docs = filtered_results;
if (l_result.reference_filter_result != nullptr || r_result.reference_filter_result != nullptr) {
copy_reference_ids(l_result.reference_filter_result != nullptr ? l_result : r_result, result);
result.docs = filtered_results;
if (!l_result.reference_filter_results.empty() || !r_result.reference_filter_results.empty()) {
copy_reference_ids(!l_result.reference_filter_results.empty() ? l_result : r_result, result);
}
}
return Option(true);
}
return do_filtering(root, result, collection_name);
}
Option<bool> Index::adaptive_filter(filter_node_t* const filter_tree_root,
filter_result_t& result,
const std::string& collection_name) const {
if (filter_tree_root == nullptr) {
return Option(true);
}
auto metrics = filter_tree_root->metrics;
if (metrics != nullptr &&
metrics->filter_exp_count > 2 &&
metrics->and_operator_count > 0 &&
// If there are more || in the filter tree than &&, we'll not gain much by rearranging the filter tree.
((float) metrics->or_operator_count / (float) metrics->and_operator_count < 0.5)) {
return rearranging_recursive_filter(filter_tree_root, result, collection_name);
} else {
return recursive_filter(filter_tree_root, result, collection_name);
}
return do_filtering(root, result, collection_name, context_ids_length, context_ids);
}
Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root,
@ -1972,7 +2169,7 @@ Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root
const std::string& collection_name) const {
std::shared_lock lock(mutex);
auto filter_op = adaptive_filter(filter_tree_root, filter_result, collection_name);
auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name);
if (!filter_op.ok()) {
return filter_op;
}
@ -1982,11 +2179,12 @@ Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root
Option<bool> Index::do_reference_filtering_with_lock(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string & reference_helper_field_name) const {
const std::string& collection_name,
const std::string& reference_helper_field_name) const {
std::shared_lock lock(mutex);
filter_result_t reference_filter_result;
auto filter_op = adaptive_filter(filter_tree_root, reference_filter_result);
auto filter_op = recursive_filter(filter_tree_root, reference_filter_result);
if (!filter_op.ok()) {
return filter_op;
}
@ -2002,21 +2200,30 @@ Option<bool> Index::do_reference_filtering_with_lock(filter_node_t* const filter
filter_result.count = reference_map.size();
filter_result.docs = new uint32_t[reference_map.size()];
filter_result.reference_filter_result = new reference_filter_result_t[reference_map.size()];
filter_result.reference_filter_results[collection_name] = new reference_filter_result_t[reference_map.size()];
size_t doc_index = 0;
for (auto &item: reference_map) {
filter_result.docs[doc_index] = item.first;
filter_result.reference_filter_result[doc_index].count = item.second.size();
filter_result.reference_filter_result[doc_index].docs = new uint32_t[item.second.size()];
std::copy(item.second.begin(), item.second.end(), filter_result.reference_filter_result[doc_index].docs);
auto& reference_result = filter_result.reference_filter_results[collection_name][doc_index];
reference_result.count = item.second.size();
reference_result.docs = new uint32_t[item.second.size()];
std::copy(item.second.begin(), item.second.end(), reference_result.docs);
doc_index++;
}
return Option(true);
}
Option<bool> Index::get_approximate_reference_filter_ids_with_lock(filter_node_t* const filter_tree_root,
uint32_t& filter_ids_length) const {
std::shared_lock lock(mutex);
return rearrange_filter_tree(filter_tree_root, filter_ids_length);
}
Option<bool> Index::run_search(search_args* search_params, const std::string& collection_name) {
return search(search_params->field_query_tokens,
search_params->search_fields,
@ -2080,7 +2287,7 @@ void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
scores[1] = int64_t(1);
scores[2] = int64_t(1);
KV kv(searched_queries.size(), seq_id, distinct_id, 0, scores);
KV kv(searched_queries.size(), seq_id, distinct_id, 0, scores, nullptr);
curated_topster->add(&kv);
}
}
@ -2503,12 +2710,16 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
const vector_query_t& vector_query,
size_t facet_sample_percent, size_t facet_sample_threshold,
const std::string& collection_name) const {
std::shared_lock lock(mutex);
uint32_t filter_ids_length = 0;
auto rearrange_op = rearrange_filter_tree(filter_tree_root, filter_ids_length, collection_name);
if (!rearrange_op.ok()) {
return rearrange_op;
}
filter_result_t filter_result;
// process the filters
auto filter_op = adaptive_filter(filter_tree_root, filter_result, collection_name);
auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name);
if (!filter_op.ok()) {
return filter_op;
}
@ -2582,7 +2793,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
int64_t match_score_index = -1;
result_ids.push_back(seq_id);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
@ -2681,7 +2893,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
//LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first;
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
@ -4615,7 +4827,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
filter_result_t result;
adaptive_filter(sort_fields_std[i].eval.filter_tree_root, result);
recursive_filter(sort_fields_std[i].eval.filter_tree_root, result);
sort_fields_std[i].eval.ids = result.docs;
sort_fields_std[i].eval.size = result.count;
result.docs = nullptr;

View File

@ -43,6 +43,61 @@ void num_tree_t::range_inclusive_search(int64_t start, int64_t end, uint32_t** i
*ids = out;
}
void num_tree_t::approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len) {
if (int64map.empty()) {
return;
}
auto it_start = int64map.lower_bound(start); // iter values will be >= start
while (it_start != int64map.end() && it_start->first <= end) {
uint32_t val_ids = ids_t::num_ids(it_start->second);
ids_len += val_ids;
it_start++;
}
}
bool num_tree_t::range_inclusive_contains(const int64_t& start, const int64_t& end, const uint32_t& id) const {
if (int64map.empty()) {
return false;
}
auto it_start = int64map.lower_bound(start); // iter values will be >= start
while (it_start != int64map.end() && it_start->first <= end) {
if (ids_t::contains(it_start->second, id)) {
return true;
}
}
return false;
}
void num_tree_t::range_inclusive_contains(const int64_t& start, const int64_t& end,
const uint32_t& context_ids_length,
uint32_t* const& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const {
if (int64map.empty()) {
return;
}
std::vector<uint32_t> consolidated_ids;
consolidated_ids.reserve(context_ids_length);
for (uint32_t i = 0; i < context_ids_length; i++) {
if (range_inclusive_contains(start, end, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
}
}
uint32_t *out = nullptr;
result_ids_len = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
result_ids, result_ids_len, &out);
delete [] result_ids;
result_ids = out;
}
size_t num_tree_t::get(int64_t value, std::vector<uint32_t>& geo_result_ids) {
const auto& it = int64map.find(value);
if(it == int64map.end()) {
@ -132,6 +187,54 @@ void num_tree_t::search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids
}
}
void num_tree_t::approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len) {
if (int64map.empty()) {
return;
}
if (comparator == EQUALS) {
const auto& it = int64map.find(value);
if (it != int64map.end()) {
uint32_t val_ids = ids_t::num_ids(it->second);
ids_len += val_ids;
}
} else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
if (iter_ge_value == int64map.end()) {
return;
}
if (comparator == GREATER_THAN && iter_ge_value->first == value) {
iter_ge_value++;
}
while (iter_ge_value != int64map.end()) {
uint32_t val_ids = ids_t::num_ids(iter_ge_value->second);
ids_len += val_ids;
iter_ge_value++;
}
} else if (comparator == LESS_THAN || comparator == LESS_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
auto it = int64map.begin();
while (it != iter_ge_value) {
uint32_t val_ids = ids_t::num_ids(it->second);
ids_len += val_ids;
it++;
}
// for LESS_THAN_EQUALS, check if last iter entry is equal to value
if (it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) {
uint32_t val_ids = ids_t::num_ids(it->second);
ids_len += val_ids;
}
}
}
void num_tree_t::remove(uint64_t value, uint32_t id) {
if(int64map.count(value) != 0) {
void* arr = int64map[value];
@ -146,6 +249,75 @@ void num_tree_t::remove(uint64_t value, uint32_t id) {
}
}
void num_tree_t::contains(const NUM_COMPARATOR& comparator, const int64_t& value,
const uint32_t& context_ids_length,
uint32_t* const& context_ids,
size_t& result_ids_len,
uint32_t*& result_ids) const {
if (int64map.empty()) {
return;
}
std::vector<uint32_t> consolidated_ids;
consolidated_ids.reserve(context_ids_length);
for (uint32_t i = 0; i < context_ids_length; i++) {
if (comparator == EQUALS) {
if (contains(value, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
}
} else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
if (iter_ge_value == int64map.end()) {
continue;
}
if (comparator == GREATER_THAN && iter_ge_value->first == value) {
iter_ge_value++;
}
while (iter_ge_value != int64map.end()) {
if (contains(iter_ge_value->first, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
break;
}
iter_ge_value++;
}
} else if(comparator == LESS_THAN || comparator == LESS_THAN_EQUALS) {
// iter entries will be >= value, or end() if all entries are before value
auto iter_ge_value = int64map.lower_bound(value);
auto it = int64map.begin();
while (it != iter_ge_value) {
if (contains(it->first, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
break;
}
it++;
}
// for LESS_THAN_EQUALS, check if last iter entry is equal to value
if (it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) {
if (contains(it->first, context_ids[i])) {
consolidated_ids.push_back(context_ids[i]);
break;
}
}
}
}
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
uint32_t *out = nullptr;
result_ids_len = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
result_ids, result_ids_len, &out);
delete[] result_ids;
result_ids = out;
}
size_t num_tree_t::size() {
return int64map.size();
}

View File

@ -386,7 +386,32 @@ void posting_t::merge(const std::vector<void*>& raw_posting_lists, std::vector<u
}
}
void posting_t::intersect(const std::vector<void*>& raw_posting_lists, std::vector<uint32_t>& result_ids) {
void posting_t::intersect(const std::vector<void*>& raw_posting_lists, std::vector<uint32_t>& result_ids,
const uint32_t& context_ids_length,
const uint32_t* context_ids) {
if (context_ids_length != 0) {
if (raw_posting_lists.empty()) {
return;
}
for (uint32_t i = 0; i < context_ids_length; i++) {
bool is_present = true;
for (auto const& raw_posting_list: raw_posting_lists) {
if (!contains(raw_posting_list, context_ids[i])) {
is_present = false;
break;
}
}
if (is_present) {
result_ids.push_back(context_ids[i]);
}
}
return;
}
// we will have to convert the compact posting list (if any) to full form
std::vector<posting_list_t*> plists;
std::vector<posting_list_t*> expanded_plists;

View File

@ -551,6 +551,297 @@ TEST_F(CollectionJoinTest, FilterByReference_MultipleMatch) {
collectionManager.drop_collection("Links");
}
TEST_F(CollectionJoinTest, AndFilterResults_NoReference) {
filter_result_t a;
a.count = 9;
a.docs = new uint32_t[a.count];
for (size_t i = 0; i < a.count; i++) {
a.docs[i] = i;
}
filter_result_t b;
b.count = 0;
uint32_t limit = 10;
b.docs = new uint32_t[limit];
for (size_t i = 2; i < limit; i++) {
if (i % 3 == 0) {
b.docs[b.count++] = i;
}
}
// a.docs: [0..8] , b.docs: [3, 6, 9]
filter_result_t result;
filter_result_t::and_filter_results(a, b, result);
ASSERT_EQ(2, result.count);
ASSERT_EQ(0, result.reference_filter_results.size());
std::vector<uint32_t> docs = {3, 6};
for(size_t i = 0; i < result.count; i++) {
ASSERT_EQ(docs[i], result.docs[i]);
}
}
TEST_F(CollectionJoinTest, AndFilterResults_WithReferences) {
filter_result_t a;
a.count = 9;
a.docs = new uint32_t[a.count];
a.reference_filter_results["foo"] = new reference_filter_result_t[a.count];
for (size_t i = 0; i < a.count; i++) {
a.docs[i] = i;
auto& reference = a.reference_filter_results["foo"][i];
reference.count = 1;
reference.docs = new uint32_t[1];
reference.docs[0] = 10 - i;
}
filter_result_t b;
b.count = 0;
uint32_t limit = 10;
b.docs = new uint32_t[limit];
b.reference_filter_results["bar"] = new reference_filter_result_t[limit];
for (size_t i = 2; i < limit; i++) {
if (i % 3 == 0) {
b.docs[b.count] = i;
auto& reference = b.reference_filter_results["bar"][b.count++];
reference.count = 1;
reference.docs = new uint32_t[1];
reference.docs[0] = 2 * i;
}
}
// a.docs: [0..8] , b.docs: [3, 6, 9]
filter_result_t result;
filter_result_t::and_filter_results(a, b, result);
ASSERT_EQ(2, result.count);
ASSERT_EQ(2, result.reference_filter_results.size());
ASSERT_EQ(1, result.reference_filter_results.count("foo"));
ASSERT_EQ(1, result.reference_filter_results.count("bar"));
std::vector<uint32_t> docs = {3, 6}, foo_reference = {7, 4}, bar_reference = {6, 12};
for(size_t i = 0; i < result.count; i++) {
ASSERT_EQ(docs[i], result.docs[i]);
ASSERT_EQ(1, result.reference_filter_results["foo"][i].count);
ASSERT_EQ(foo_reference[i], result.reference_filter_results["foo"][i].docs[0]);
ASSERT_EQ(1, result.reference_filter_results["bar"][i].count);
ASSERT_EQ(bar_reference[i], result.reference_filter_results["bar"][i].docs[0]);
}
}
TEST_F(CollectionJoinTest, FilterByNReferences) {
auto schema_json =
R"({
"name": "Users",
"fields": [
{"name": "user_id", "type": "string"},
{"name": "user_name", "type": "string"}
]
})"_json;
std::vector<nlohmann::json> documents = {
R"({
"user_id": "user_a",
"user_name": "Roshan"
})"_json,
R"({
"user_id": "user_b",
"user_name": "Ruby"
})"_json,
R"({
"user_id": "user_c",
"user_name": "Joe"
})"_json,
R"({
"user_id": "user_d",
"user_name": "Aby"
})"_json
};
auto collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
schema_json =
R"({
"name": "Repos",
"fields": [
{"name": "repo_id", "type": "string"},
{"name": "repo_content", "type": "string"},
{"name": "repo_stars", "type": "int32"},
{"name": "repo_is_private", "type": "bool"}
]
})"_json;
documents = {
R"({
"repo_id": "repo_a",
"repo_content": "body1",
"repo_stars": 431,
"repo_is_private": true
})"_json,
R"({
"repo_id": "repo_b",
"repo_content": "body2",
"repo_stars": 4562,
"repo_is_private": false
})"_json,
R"({
"repo_id": "repo_c",
"repo_content": "body3",
"repo_stars": 945,
"repo_is_private": false
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
schema_json =
R"({
"name": "Links",
"fields": [
{"name": "repo_id", "type": "string", "reference": "Repos.repo_id"},
{"name": "user_id", "type": "string", "reference": "Users.user_id"}
]
})"_json;
documents = {
R"({
"repo_id": "repo_a",
"user_id": "user_b"
})"_json,
R"({
"repo_id": "repo_a",
"user_id": "user_c"
})"_json,
R"({
"repo_id": "repo_b",
"user_id": "user_a"
})"_json,
R"({
"repo_id": "repo_b",
"user_id": "user_b"
})"_json,
R"({
"repo_id": "repo_b",
"user_id": "user_d"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_a"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_b"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_c"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_d"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
schema_json =
R"({
"name": "Organizations",
"fields": [
{"name": "org_id", "type": "string"},
{"name": "org_name", "type": "string"}
]
})"_json;
documents = {
R"({
"org_id": "org_a",
"org_name": "Typesense"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
schema_json =
R"({
"name": "Participants",
"fields": [
{"name": "user_id", "type": "string", "reference": "Users.user_id"},
{"name": "org_id", "type": "string", "reference": "Organizations.org_id"}
]
})"_json;
documents = {
R"({
"user_id": "user_a",
"org_id": "org_a"
})"_json,
R"({
"user_id": "user_b",
"org_id": "org_a"
})"_json,
R"({
"user_id": "user_d",
"org_id": "org_a"
})"_json,
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
auto coll = collectionManager.get_collection_unsafe("Users");
// Search for users within an organization with access to a particular repo.
auto result = coll->search("R", {"user_name"}, "$Participants(org_id:=org_a) && $Links(repo_id:=repo_b)", {}, {}, {0},
10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD).get();
ASSERT_EQ(2, result["found"].get<size_t>());
ASSERT_EQ(2, result["hits"].size());
ASSERT_EQ("user_b", result["hits"][0]["document"]["user_id"].get<std::string>());
ASSERT_EQ("user_a", result["hits"][1]["document"]["user_id"].get<std::string>());
collectionManager.drop_collection("Users");
collectionManager.drop_collection("Repos");
collectionManager.drop_collection("Links");
}
TEST_F(CollectionJoinTest, IncludeFieldsByReference_SingleMatch) {
auto schema_json =
R"({
@ -651,11 +942,11 @@ TEST_F(CollectionJoinTest, IncludeFieldsByReference_SingleMatch) {
ASSERT_FALSE(search_op.ok());
ASSERT_EQ("Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`.", search_op.error());
req_params["include_fields"] = "$foo(bar)";
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_FALSE(search_op.ok());
ASSERT_EQ("Referenced collection `foo` not found.", search_op.error());
// req_params["include_fields"] = "$foo(bar)";
// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
// ASSERT_FALSE(search_op.ok());
// ASSERT_EQ("Referenced collection `foo` not found.", search_op.error());
//
// req_params["include_fields"] = "$Customers(bar)";
// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
// ASSERT_TRUE(search_op.ok());
@ -709,4 +1000,4 @@ TEST_F(CollectionJoinTest, IncludeFieldsByReference_SingleMatch) {
// // 3 fields in Products document and 2 fields from Customers document
// ASSERT_EQ(5, res_obj["hits"][0]["document"].size());
// ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id_sequence_id"));
}
}

View File

@ -1973,3 +1973,235 @@ TEST_F(CollectionSpecificMoreTest, CrossFieldTypoAndPrefixWithWeights) {
"<mark>", "</mark>", {2, 3}).get();
ASSERT_EQ(1, res["hits"].size());
}
TEST_F(CollectionSpecificMoreTest, RearrangingFilterTree) {
nlohmann::json schema =
R"({
"name": "Collection",
"fields": [
{"name": "name", "type": "string"},
{"name": "age", "type": "int32"},
{"name": "years", "type": "int32[]"},
{"name": "rating", "type": "float"}
]
})"_json;
Collection* coll = collectionManager.create_collection(schema).get();
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
std::string json_line;
while (std::getline(infile, json_line)) {
auto add_op = coll->add(json_line);
ASSERT_TRUE(add_op.ok());
}
infile.close();
const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_";
filter_node_t* filter_tree_root = nullptr;
Option<bool> filter_op = filter::parse_filter_query("years:>2000 && ((age:<30 && rating:>5) || (age:>50 && rating:<5))",
coll->get_schema(), store, doc_id_prefix, filter_tree_root);
ASSERT_TRUE(filter_op.ok());
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);
// &&
// / \
// years>2000 ||
// 4 / \
// / &&
// && / \
// / \ age>50 rating<5
// / \ 1 2
// / \
// age<30 rating>5
// 2 3
ASSERT_TRUE(filter_tree_root != nullptr);
ASSERT_TRUE(filter_tree_root->isOperator);
ASSERT_EQ(filter_tree_root->filter_operator, AND);
auto root = filter_tree_root->left;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "years");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->right;
ASSERT_TRUE(root != nullptr);
ASSERT_TRUE(root->isOperator);
ASSERT_EQ(root->filter_operator, OR);
root = filter_tree_root->right->left;
ASSERT_TRUE(root != nullptr);
ASSERT_TRUE(root->isOperator);
ASSERT_EQ(root->filter_operator, AND);
root = filter_tree_root->right->left->left;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "age");
ASSERT_EQ(root->filter_exp.comparators.front(), LESS_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "30");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->right->left->right;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "rating");
ASSERT_EQ(root->filter_exp.comparators.front(), GREATER_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "5");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->right->right;
ASSERT_TRUE(root != nullptr);
ASSERT_TRUE(root->isOperator);
ASSERT_EQ(root->filter_operator, AND);
root = filter_tree_root->right->right->left;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "age");
ASSERT_EQ(root->filter_exp.comparators.front(), GREATER_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "50");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->right->right->right;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "rating");
ASSERT_EQ(root->filter_exp.comparators.front(), LESS_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "5");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
uint32_t count = 0;
coll->_get_index()->rearrange_filter_tree(filter_tree_root, count);
// &&
// / \
// || years>2000
// / \
// && \
// / \ \
// age>50 rating<5 &&
// / \
// age<30 rating>5
ASSERT_TRUE(filter_tree_root != nullptr);
ASSERT_TRUE(filter_tree_root->isOperator);
ASSERT_EQ(filter_tree_root->filter_operator, AND);
root = filter_tree_root->left;
ASSERT_TRUE(root != nullptr);
ASSERT_TRUE(root->isOperator);
ASSERT_EQ(root->filter_operator, OR);
root = filter_tree_root->left->left;
ASSERT_TRUE(root != nullptr);
ASSERT_TRUE(root->isOperator);
ASSERT_EQ(root->filter_operator, AND);
root = filter_tree_root->left->left->left;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "age");
ASSERT_EQ(root->filter_exp.comparators.front(), GREATER_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "50");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->left->left->right;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "rating");
ASSERT_EQ(root->filter_exp.comparators.front(), LESS_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "5");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->left->right;
ASSERT_TRUE(root != nullptr);
ASSERT_TRUE(root->isOperator);
ASSERT_EQ(root->filter_operator, AND);
root = filter_tree_root->left->right->left;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "age");
ASSERT_EQ(root->filter_exp.comparators.front(), LESS_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "30");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->left->right->right;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "rating");
ASSERT_EQ(root->filter_exp.comparators.front(), GREATER_THAN);
ASSERT_EQ(root->filter_exp.values.front(), "5");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
root = filter_tree_root->right;
ASSERT_TRUE(root != nullptr);
ASSERT_FALSE(root->isOperator);
ASSERT_EQ(root->filter_exp.field_name, "years");
ASSERT_TRUE(root->left == nullptr);
ASSERT_TRUE(root->right == nullptr);
collectionManager.drop_collection("Collection");
}
TEST_F(CollectionSpecificMoreTest, ApproxFilterMatchCount) {
nlohmann::json schema =
R"({
"name": "Collection",
"fields": [
{"name": "name", "type": "string"},
{"name": "age", "type": "int32"},
{"name": "years", "type": "int32[]"},
{"name": "rating", "type": "float"},
{"name": "location", "type": "geopoint", "optional": true}
]
})"_json;
Collection *coll = collectionManager.create_collection(schema).get();
std::ifstream infile(std::string(ROOT_DIR) + "test/numeric_array_documents.jsonl");
std::string json_line;
while (std::getline(infile, json_line)) {
auto add_op = coll->add(json_line);
ASSERT_TRUE(add_op.ok());
}
infile.close();
const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_";
filter_node_t* filter_tree_root = nullptr;
Option<bool> filter_op = filter::parse_filter_query("name: Jeremy", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
uint32_t approx_count;
coll->_get_index()->_approximate_filter_ids(filter_tree_root->filter_exp, approx_count);
ASSERT_EQ(approx_count, 5);
delete filter_tree_root;
filter_op = filter::parse_filter_query("location:(48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469)",
coll->get_schema(), store, doc_id_prefix, filter_tree_root);
ASSERT_TRUE(filter_op.ok());
coll->_get_index()->_approximate_filter_ids(filter_tree_root->filter_exp, approx_count);
ASSERT_EQ(approx_count, 100);
delete filter_tree_root;
filter_op = filter::parse_filter_query("years:>2000 && ((age:<30 && rating:>5) || (age:>50 && rating:<5))",
coll->get_schema(), store, doc_id_prefix, filter_tree_root);
ASSERT_TRUE(filter_op.ok());
coll->_get_index()->rearrange_filter_tree(filter_tree_root, approx_count);
ASSERT_EQ(approx_count, 3);
delete filter_tree_root;
collectionManager.drop_collection("Collection");
}