diff --git a/include/field.h b/include/field.h index a4ac81b8..87a17702 100644 --- a/include/field.h +++ b/include/field.h @@ -660,6 +660,67 @@ struct filter_result_t { delete[] item.second; } } + + static void and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { + auto lenA = a.count, lenB = b.count; + if (lenA == 0 || lenB == 0) { + return; + } + + result.docs = new uint32_t[std::min(lenA, lenB)]; + + auto A = a.docs, B = b.docs, out = result.docs; + const uint32_t *endA = A + lenA; + const uint32_t *endB = B + lenB; + + for (auto const& item: a.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; + } + } + for (auto const& item: b.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; + } + } + + while (true) { + while (*A < *B) { + SKIP_FIRST_COMPARE: + if (++A == endA) { + result.count = out - result.docs; + return; + } + } + while (*A > *B) { + if (++B == endB) { + result.count = out - result.docs; + return; + } + } + if (*A == *B) { + *out = *A; + + for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][out - result.docs] = item.second[A - a.docs]; + item.second[A - a.docs].docs = nullptr; + } + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][out - result.docs] = item.second[B - b.docs]; + item.second[B - b.docs].docs = nullptr; + } + + out++; + + if (++A == endA || ++B == endB) { + result.count = out - result.docs; + return; + } + } else { + goto SKIP_FIRST_COMPARE; + } + } + } }; namespace sort_field_const { diff --git a/src/index.cpp b/src/index.cpp index 3f6809b7..0e09cc8b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2095,67 +2095,6 @@ Option Index::rearrange_filter_tree(filter_node_t* const root, return Option(true); } -void and_filter_result(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { - auto lenA = a.count, lenB = b.count; - if (lenA == 0 || lenB == 0) { - return; - } - - result.docs = new uint32_t[std::min(lenA, lenB)]; - - auto A = a.docs, B = b.docs, out = result.docs; - const uint32_t *endA = A + lenA; - const uint32_t *endB = B + lenB; - - for (auto const& item: a.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; - } - } - for (auto const& item: b.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; - } - } - - while (true) { - while (*A < *B) { - SKIP_FIRST_COMPARE: - if (++A == endA) { - result.count = out - result.docs; - return; - } - } - while (*A > *B) { - if (++B == endB) { - result.count = out - result.docs; - return; - } - } - if (*A == *B) { - *out = *A; - - for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][out - result.docs] = item.second[A - a.docs]; - item.second[A - a.docs].docs = nullptr; - } - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][out - result.docs] = item.second[B - b.docs]; - item.second[B - b.docs].docs = nullptr; - } - - out++; - - if (++A == endA || ++B == endB) { - result.count = out - result.docs; - return; - } - } else { - goto SKIP_FIRST_COMPARE; - } - } -} - void copy_reference_ids(filter_result_t& from, filter_result_t& to) { if (to.count > 0 && !from.reference_filter_results.empty()) { for (const auto &item: from.reference_filter_results) { @@ -2206,7 +2145,7 @@ Option Index::recursive_filter(filter_node_t* const root, } if (root->filter_operator == AND) { - and_filter_result(l_result, r_result, result); + filter_result_t::and_filter_results(l_result, r_result, result); } else { uint32_t* filtered_results = nullptr; result.count = ArrayUtils::or_scalar( diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 6df9d397..23c4b022 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -551,6 +551,89 @@ TEST_F(CollectionJoinTest, FilterByReference_MultipleMatch) { collectionManager.drop_collection("Links"); } +TEST_F(CollectionJoinTest, AndFilterResults_NoReference) { + filter_result_t a; + a.count = 9; + a.docs = new uint32_t[a.count]; + for (size_t i = 0; i < a.count; i++) { + a.docs[i] = i; + } + + filter_result_t b; + b.count = 0; + uint32_t limit = 10; + b.docs = new uint32_t[limit]; + for (size_t i = 2; i < limit; i++) { + if (i % 3 == 0) { + b.docs[b.count++] = i; + } + } + + // a.docs: [0..8] , b.docs: [3, 6, 9] + filter_result_t result; + filter_result_t::and_filter_results(a, b, result); + + ASSERT_EQ(2, result.count); + ASSERT_EQ(0, result.reference_filter_results.size()); + + std::vector docs = {3, 6}; + + for(size_t i = 0; i < result.count; i++) { + ASSERT_EQ(docs[i], result.docs[i]); + } +} + +TEST_F(CollectionJoinTest, AndFilterResults_WithReferences) { + filter_result_t a; + a.count = 9; + a.docs = new uint32_t[a.count]; + a.reference_filter_results["foo"] = new reference_filter_result_t[a.count]; + for (size_t i = 0; i < a.count; i++) { + a.docs[i] = i; + + auto& reference = a.reference_filter_results["foo"][i]; + reference.count = 1; + reference.docs = new uint32_t[1]; + reference.docs[0] = 10 - i; + } + + filter_result_t b; + b.count = 0; + uint32_t limit = 10; + b.docs = new uint32_t[limit]; + b.reference_filter_results["bar"] = new reference_filter_result_t[limit]; + for (size_t i = 2; i < limit; i++) { + if (i % 3 == 0) { + b.docs[b.count] = i; + + auto& reference = b.reference_filter_results["bar"][b.count++]; + reference.count = 1; + reference.docs = new uint32_t[1]; + reference.docs[0] = 2 * i; + } + } + + // a.docs: [0..8] , b.docs: [3, 6, 9] + filter_result_t result; + filter_result_t::and_filter_results(a, b, result); + + ASSERT_EQ(2, result.count); + ASSERT_EQ(2, result.reference_filter_results.size()); + ASSERT_EQ(1, result.reference_filter_results.count("foo")); + ASSERT_EQ(1, result.reference_filter_results.count("bar")); + + std::vector docs = {3, 6}, foo_reference = {7, 4}, bar_reference = {6, 12}; + + for(size_t i = 0; i < result.count; i++) { + ASSERT_EQ(docs[i], result.docs[i]); + + ASSERT_EQ(1, result.reference_filter_results["foo"][i].count); + ASSERT_EQ(foo_reference[i], result.reference_filter_results["foo"][i].docs[0]); + ASSERT_EQ(1, result.reference_filter_results["bar"][i].count); + ASSERT_EQ(bar_reference[i], result.reference_filter_results["bar"][i].docs[0]); + } +} + TEST_F(CollectionJoinTest, FilterByNReferences) { auto schema_json = R"({