diff --git a/include/collection.h b/include/collection.h
index 720631c0..71b879b1 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -436,6 +436,8 @@ public:
     Option<bool> get_filter_ids(const std::string & filter_query,
                                 std::vector<std::pair<size_t, uint32_t*>>& index_ids) const;
 
+    Option<bool> validate_reference_filter(const std::string& filter_query) const;
+
     Option<nlohmann::json> get(const std::string & id) const;
 
     Option<std::string> remove(const std::string & id, bool remove_from_store = true);
diff --git a/include/field.h b/include/field.h
index ca29fa43..ee69bff5 100644
--- a/include/field.h
+++ b/include/field.h
@@ -77,7 +77,7 @@ struct field {
     static constexpr int VAL_UNKNOWN = 2;
 
-    std::string reference; // Reference to another collection.
+    std::string reference; // Foo.bar (reference to bar field in Foo collection).
 
     field() {}
@@ -448,6 +448,9 @@ struct filter {
     // aggregated and then this flag is checked if negation on the aggregated result is required.
     bool apply_not_equals = false;
 
+    // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)`
+    std::string referenced_collection_name;
+
     static const std::string RANGE_OPERATOR() {
         return "..";
     }
diff --git a/src/collection.cpp b/src/collection.cpp
index da1f0e5a..516a8bf3 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -2362,6 +2362,23 @@ Option<bool> Collection::get_filter_ids(const std::string & filter_query,
     return Option<bool>(true);
 }
 
+Option<bool> Collection::validate_reference_filter(const std::string& filter_query) const {
+    std::shared_lock lock(mutex);
+
+    const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_";
+    filter_node_t* filter_tree_root = nullptr;
+    Option<bool> filter_op = filter::parse_filter_query(filter_query, search_schema,
+                                                        store, doc_id_prefix, filter_tree_root);
+
+    if(!filter_op.ok()) {
+        delete filter_tree_root; // parse may have partially built the tree before failing
+        return filter_op;
+    }
+
+    delete filter_tree_root;
+    return Option<bool>(true);
+}
+
 bool Collection::facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count,
                                        const nlohmann::json &document,
                                       std::string &value) const {
diff --git a/src/field.cpp b/src/field.cpp
index 4275577a..f0bd4eaa 100644
--- a/src/field.cpp
+++ b/src/field.cpp
@@ -2,6 +2,7 @@
 #include "field.h"
 #include "magic_enum.hpp"
 #include <stack>
+#include <collection_manager.h>
 
 Option<bool> filter::parse_geopoint_filter_value(std::string& raw_value,
                                                  const std::string& format_err_msg,
@@ -408,9 +409,32 @@ Option<bool> toParseTree(std::queue<std::string>& postfix, filter_node_t*& root,
         filter_node = new filter_node_t(expression == "&&" ? AND : OR, operandA, operandB);
     } else {
         filter filter_exp;
-        Option<bool> toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix);
-        if (!toFilter_op.ok()) {
-            return toFilter_op;
+
+        // Expected value: $Collection(...)
+        bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')');
+        if (is_referenced_filter) {
+            size_t parenthesis_index = expression.find('(');
+
+            std::string collection_name = expression.substr(1, parenthesis_index - 1);
+            auto& cm = CollectionManager::get_instance();
+            auto collection = cm.get_collection(collection_name);
+            if (collection == nullptr) {
+                return Option<bool>(400, "Referenced collection `" + collection_name + "` not found.");
+            }
+
+            filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)};
+            filter_exp.referenced_collection_name = collection_name;
+
+            auto op = collection->validate_reference_filter(filter_exp.field_name);
+            if (!op.ok()) {
+                return Option<bool>(400, "Failed to parse reference filter on `" + collection_name +
+                                         "` collection: " + op.error());
+            }
+        } else {
+            Option<bool> toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix);
+            if (!toFilter_op.ok()) {
+                return toFilter_op;
+            }
         }
 
         filter_node = new filter_node_t(filter_exp);
diff --git a/src/index.cpp b/src/index.cpp
index c9f79b38..e21dddcc 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -470,8 +470,7 @@ Option<bool> Index::validate_index_in_memory(nlohmann::json& document, uint3
                                              "Multiple documents having"
                                              + match + "found in the collection `" + tokens[0] + "`.");
         }
 
-        document[a_field.name + "_sequence_id"] = collection->get_seq_id_collection_prefix() + "_" +
-                                                  StringUtils::serialize_uint32_t(*(documents[0].second));
+        document[a_field.name + "_sequence_id"] = StringUtils::serialize_uint32_t(*(documents[0].second));
         delete [] documents[0].second;
     }
@@ -1668,6 +1667,63 @@ void Index::do_filtering(uint32_t*& filter_ids,
     // auto begin = std::chrono::high_resolution_clock::now();
     const filter a_filter = root->filter_exp;
 
+    bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
+    if (is_referenced_filter) {
+        // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents.
+        auto& cm = CollectionManager::get_instance();
+        auto collection = cm.get_collection(a_filter.referenced_collection_name);
+        if (collection == nullptr) {
+            // Referenced collection could have been dropped since the filter was parsed.
+            return;
+        }
+
+        std::vector<std::pair<size_t, uint32_t*>> documents;
+        auto op = collection->get_filter_ids(a_filter.field_name, documents);
+        if (!op.ok()) {
+            return;
+        }
+
+        if (documents[0].first > 0) {
+            const field* reference_field = nullptr;
+            for (auto const& f: collection->get_fields()) {
+                auto this_collection_name = cm.get_collection_with_id(collection_id)->get_name();
+                if (!f.reference.empty() &&
+                    f.reference.find(this_collection_name) == 0 &&
+                    f.reference.find('.') == this_collection_name.size()) {
+                    reference_field = &f;
+                    break;
+                }
+            }
+
+            if (reference_field == nullptr) {
+                delete [] documents[0].second;
+                return;
+            }
+
+            std::vector<uint32_t> result_ids;
+            for (size_t i = 0; i < documents[0].first; i++) {
+                uint32_t seq_id = *(documents[0].second + i);
+
+                nlohmann::json document;
+                auto doc_op = collection->get_document_from_store(seq_id, document);
+                if (!doc_op.ok()) {
+                    delete [] documents[0].second;
+                    return;
+                }
+
+                result_ids.push_back(StringUtils::deserialize_uint32_t(document[reference_field->name + "_sequence_id"].get<std::string>()));
+            }
+
+            filter_ids = new uint32_t[result_ids.size()];
+            std::sort(result_ids.begin(), result_ids.end());
+            std::copy(result_ids.begin(), result_ids.end(), filter_ids);
+
+            filter_ids_length = result_ids.size();
+        }
+
+        delete [] documents[0].second;
+        return;
+    }
+
     if (a_filter.field_name == "id") {
         // we handle `ids` separately
         std::vector<uint32_t> result_ids;
diff --git a/src/string_utils.cpp b/src/string_utils.cpp
index e7893004..a9409400 100644
--- a/src/string_utils.cpp
+++ b/src/string_utils.cpp
@@ -349,9 +349,40 @@ size_t StringUtils::get_num_chars(const std::string& s) {
     return j;
 }
 
+Option<bool> parse_reference_filter(const std::string& filter_query, std::queue<std::string>& tokens, size_t& index) {
+    auto error = Option<bool>(400, "Could not parse the reference filter.");
+    if (filter_query[index] != '$') {
+        return error;
+    }
+
+    size_t start_index = index;
+    auto size = filter_query.size();
+    while(++index < size && filter_query[index] != '(') {}
+
+    if (index >= size) {
+        return error;
+    }
+
+    int parenthesis_count = 1;
+    while (++index < size && parenthesis_count > 0) {
+        if (filter_query[index] == '(') {
+            parenthesis_count++;
+        } else if (filter_query[index] == ')') {
+            parenthesis_count--;
+        }
+    }
+
+    if (parenthesis_count != 0) {
+        return error;
+    }
+
+    tokens.push(filter_query.substr(start_index, index - start_index));
+    return Option<bool>(true);
+}
+
 Option<bool> StringUtils::tokenize_filter_query(const std::string& filter_query, std::queue<std::string>& tokens) {
     auto size = filter_query.size();
-    for (auto i = 0; i < size;) {
+    for (size_t i = 0; i < size;) {
         auto c = filter_query[i];
         if (c == ' ') {
             i++;
@@ -377,6 +408,15 @@ Option<bool> StringUtils::tokenize_filter_query(const std::string& filter_query,
             tokens.push("||");
             i += 2;
         } else {
+            // Reference filter would start with $ symbol.
+ if (c == '$') { + auto op = parse_reference_filter(filter_query, tokens, i); + if (!op.ok()) { + return op; + } + continue; + } + std::stringstream ss; bool inBacktick = false; bool preceding_colon = false; diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 9964c864..26a7d476 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -93,14 +93,14 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { ASSERT_EQ(schema.at("customer_name").reference, ""); ASSERT_EQ(schema.at("product_id").reference, "Products.product_id"); - // Index a `foo_sequence_id` field for `foo` reference field. + // Add a `foo_sequence_id` field in the schema for `foo` reference field. ASSERT_EQ(schema.count("product_id_sequence_id"), 1); ASSERT_TRUE(schema.at("product_id_sequence_id").index); collectionManager.drop_collection("Customers"); } -TEST_F(CollectionJoinTest, IndexReferenceField) { +TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { auto products_schema_json = R"({ "name": "Products", @@ -272,12 +272,123 @@ TEST_F(CollectionJoinTest, IndexReferenceField) { ASSERT_TRUE(add_op.ok()); ASSERT_EQ(customer_collection->get("0").get().count("product_id_sequence_id"), 1); + // Referenced document should be accessible from Customers collection. 
+ auto sequence_id = collectionManager.get_collection("Products")->get_seq_id_collection_prefix() + "_" + + customer_collection->get("0").get()["product_id_sequence_id"].get(); nlohmann::json document; - auto get_op = customer_collection->get_document_from_store(customer_collection->get("0").get()["product_id_sequence_id"].get(), document); + auto get_op = customer_collection->get_document_from_store(sequence_id, document); ASSERT_TRUE(get_op.ok()); ASSERT_EQ(document.count("product_id"), 1); ASSERT_EQ(document["product_id"], "product_a"); + ASSERT_EQ(document["product_name"], "shampoo"); collectionManager.drop_collection("Customers"); collectionManager.drop_collection("Products"); +} + +TEST_F(CollectionJoinTest, FilterByReferenceField) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string"}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + std::vector documents = { + R"({ + "product_id": "product_a", + "product_name": "shampoo", + "product_description": "Our new moisturizing shampoo is perfect for those with dry or damaged hair." + })"_json, + R"({ + "product_id": "product_b", + "product_name": "soap", + "product_description": "Introducing our all-natural, organic soap bar made with essential oils and botanical ingredients." 
+ })"_json + }; + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "Products.product_id"} + ] + })"_json; + documents = { + R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 143, + "product_id": "product_a" + })"_json, + R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 73.5, + "product_id": "product_b" + })"_json, + R"({ + "customer_id": "customer_b", + "customer_name": "Dan", + "product_price": 75, + "product_id": "product_a" + })"_json, + R"({ + "customer_id": "customer_b", + "customer_name": "Dan", + "product_price": 140, + "product_id": "product_b" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + auto coll = collectionManager.get_collection("Products"); + auto search_op = coll->search("s", {"product_name"}, "$foo:=customer_a", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Could not parse the reference filter."); + + search_op = coll->search("s", {"product_name"}, "$foo(:=customer_a", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Could not parse 
the reference filter."); + + search_op = coll->search("s", {"product_name"}, "$foo(:=customer_a)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Referenced collection `foo` not found."); + + search_op = coll->search("s", {"product_name"}, "$Customers(foo:=customer_a)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Failed to parse reference filter on `Customers` collection: Could not find a filter field named `foo` in the schema."); + + auto result = coll->search("s", {"product_name"}, "$Customers(customer_id:=customer_a && product_price:<100)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD).get(); + + ASSERT_EQ(1, result["found"].get()); + ASSERT_EQ(1, result["hits"].size()); + ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get()); } \ No newline at end of file diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index 1d2e0246..31acfafb 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -393,4 +393,8 @@ TEST(StringUtilsTest, TokenizeFilterQuery) { filter_query = "((age:<5||age:>10)&&location:(48.906,2.343,5mi))||tags:AT&T"; tokenList = {"(", "(", "age:<5", "||", "age:>10", ")", "&&", "location:(48.906,2.343,5mi)", ")", "||", "tags:AT&T"}; tokenizeTestHelper(filter_query, tokenList); + + filter_query = "((age: <5 || age: >10) && category:= [shoes]) && $Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; + tokenList = {"(", "(", "age: <5", "||", "age: >10", ")", "&&", "category:= [shoes]", ")", "&&", "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"}; + tokenizeTestHelper(filter_query, tokenList); }