From 56d382d1e0c5ed3197f373f48c18184bee30cfeb Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 16 Jan 2023 13:39:17 +0530 Subject: [PATCH 01/27] Add reference in field struct. --- include/field.h | 7 ++- src/field.cpp | 17 ++++++- test/collection_join_test.cpp | 96 +++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 test/collection_join_test.cpp diff --git a/include/field.h b/include/field.h index 242f8f9b..d7960716 100644 --- a/include/field.h +++ b/include/field.h @@ -47,6 +47,7 @@ namespace fields { static const std::string nested_array = "nested_array"; static const std::string num_dim = "num_dim"; static const std::string vec_dist = "vec_dist"; + static const std::string reference = "reference"; } enum vector_distance_type_t { @@ -76,13 +77,15 @@ struct field { static constexpr int VAL_UNKNOWN = 2; + std::string reference; // Reference to another collection. + field() {} field(const std::string &name, const std::string &type, const bool facet, const bool optional = false, bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false, - int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine) : + int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "") : name(name), type(type), facet(facet), optional(optional), index(index), locale(locale), - nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist) { + nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference) { set_computed_defaults(sort, infix); } diff --git a/src/field.cpp b/src/field.cpp index c9df953f..f3a407e3 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -525,6 +525,12 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } } + if (field_json.count(fields::reference) != 0 && !field_json.at(fields::reference).is_string()) { + return Option(400, "Reference should be a string."); + } else if (field_json.count(fields::reference) == 0) { + field_json[fields::reference] = ""; + } + if(field_json["name"] == ".*") { if(field_json.count(fields::facet) == 0) { field_json[fields::facet] = false; @@ -562,6 +568,10 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso return Option(400, "Field `.*` must be an index field."); } + if (field_json.count(fields::reference) != 0) { + return Option(400, "Field `.*` cannot be a reference field."); + } + field fallback_field(field_json["name"], field_json["type"], field_json["facet"], field_json["optional"], field_json[fields::index], field_json[fields::locale], field_json[fields::sort], field_json[fields::infix]); @@ -659,6 +669,10 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso bool is_obj = field_json[fields::type] == field_types::OBJECT || field_json[fields::type] == field_types::OBJECT_ARRAY; bool is_regexp_name = field_json[fields::name].get().find(".*") != std::string::npos; + if (is_regexp_name && field_json.count(fields::reference) != 0) { + return Option(400, "Wildcard field cannot have a reference."); + } + if(is_obj || (!is_regexp_name && enable_nested_fields && field_json[fields::name].get().find('.') != std::string::npos)) { field_json[fields::nested] = true; @@ -679,7 +693,8 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field(field_json[fields::name], field_json[fields::type], field_json[fields::facet], field_json[fields::optional], field_json[fields::index], field_json[fields::locale], field_json[fields::sort], field_json[fields::infix], field_json[fields::nested], - field_json[fields::nested_array], field_json[fields::num_dim], vec_dist) + field_json[fields::nested_array], field_json[fields::num_dim], vec_dist, + field_json[fields::reference]) ); return Option(true); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp new file mode 100644 index 00000000..0065e8ae --- /dev/null +++ b/test/collection_join_test.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class CollectionJoinTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_join"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(CollectionJoinTest, SchemaReferenceField) { + nlohmann::json schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "product_.*", "type": "string", "reference": "Products.product_id"} + ] + })"_json; + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_FALSE(collection_create_op.ok()); + ASSERT_EQ("Wildcard field cannot have a reference.", collection_create_op.error()); + + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": ".*", "type": "auto", "reference": "Products.product_id"} + ] + })"_json; + + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_FALSE(collection_create_op.ok()); + ASSERT_EQ("Field `.*` cannot be a reference field.", collection_create_op.error()); + + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "product_id", "type": "string", "reference": 123}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"} + ] + })"_json; + + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_FALSE(collection_create_op.ok()); + ASSERT_EQ("Reference should be a string.", collection_create_op.error()); + + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "product_id", "type": "string", "reference": "Products.product_id"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"} + ] + })"_json; + + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + auto collection = collection_create_op.get(); + auto schema = collection->get_schema(); + + ASSERT_EQ(schema.at("customer_name").reference, ""); + ASSERT_EQ(schema.at("product_id").reference, "Products.product_id"); + collectionManager.drop_collection("Customers"); +} \ No newline at end of file From bb1c72a8f40fea394f86384bfb1ad69a4a46927a Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 16 Jan 2023 19:21:17 +0530 Subject: [PATCH 02/27] Persist `reference` field property. --- include/field.h | 4 ++++ src/collection.cpp | 4 ++++ src/collection_manager.cpp | 6 +++++- src/field.cpp | 4 ++-- test/collection_manager_test.cpp | 19 +++++++++++++++++-- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/include/field.h b/include/field.h index d7960716..ca29fa43 100644 --- a/include/field.h +++ b/include/field.h @@ -310,6 +310,10 @@ struct field { field_val[fields::vec_dist] = field.vec_dist == ip ? "ip" : "cosine"; } + if (!field.reference.empty()) { + field_val[fields::reference] = field.reference; + } + fields_json.push_back(field_val); if(!field.has_valid_type()) { diff --git a/src/collection.cpp b/src/collection.cpp index e2d2eb4e..ab162605 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -180,6 +180,10 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::num_dim] = coll_field.num_dim; } + if (!coll_field.reference.empty()) { + field_json[fields::reference] = coll_field.reference; + } + fields_arr.push_back(field_json); } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index b3f8a2a3..0fb4d5be 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -54,6 +54,10 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection field_obj[fields::num_dim] = 0; } + if (field_obj.count(fields::reference) == 0) { + field_obj[fields::reference] = ""; + } + vector_distance_type_t vec_dist_type = vector_distance_type_t::cosine; if(field_obj.count(fields::vec_dist) != 0) { @@ -66,7 +70,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet], field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale], -1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array], - field_obj[fields::num_dim], vec_dist_type); + field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference]); // value of `sort` depends on field type if(field_obj.count(fields::sort) == 0) { diff --git a/src/field.cpp b/src/field.cpp index f3a407e3..c184d8ab 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -568,7 +568,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso return Option(400, "Field `.*` must be an index field."); } - if (field_json.count(fields::reference) != 0) { + if (!field_json[fields::reference].get().empty()) { return Option(400, "Field `.*` cannot be a reference field."); } @@ -669,7 +669,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso bool is_obj = field_json[fields::type] == field_types::OBJECT || field_json[fields::type] == field_types::OBJECT_ARRAY; bool is_regexp_name = field_json[fields::name].get().find(".*") != std::string::npos; - if (is_regexp_name && field_json.count(fields::reference) != 0) { + if (is_regexp_name && !field_json[fields::reference].get().empty()) { return Option(400, "Wildcard field cannot have a reference."); } diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 75a1c77c..dc835fe2 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -36,7 +36,8 @@ protected: {"name": "not_stored", "type": "string", "optional": true, "index": false}, {"name": "points", "type": "int32"}, {"name": "person", "type": "object", "optional": true}, - {"name": "vec", "type": "float[]", "num_dim": 128, "optional": true} + {"name": "vec", "type": "float[]", "num_dim": 128, "optional": true}, + {"name": "product_id", "type": "string", "reference": "Products.product_id"} ], "default_sorting_field": "points", "symbols_to_index":["+"], @@ -44,7 +45,9 @@ protected: })"_json; sort_fields = { sort_by("points", "DESC") }; - collection1 = collectionManager.create_collection(schema).get(); + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + collection1 = op.get(); } virtual void SetUp() { @@ -210,6 +213,18 @@ TEST_F(CollectionManagerTest, CollectionCreation) { "sort":false, "type":"float[]", "vec_dist":"cosine" + }, + { + "facet":false, + "index":true, + "infix":false, + "locale":"", + "name":"product_id", + "nested":false, + "optional":false, + "sort":false, + "type":"string", + "reference":"Products.product_id" } ], "id":0, From aad34c8c5c5bc8321c27bab826ca2c7fc49f4e60 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 17 Jan 2023 09:19:08 +0530 Subject: [PATCH 03/27] Fix CollectionManagerTest.RestoreRecordsOnRestart --- test/collection_manager_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index dc835fe2..882f92fd 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -37,7 +37,7 @@ protected: {"name": "points", "type": "int32"}, {"name": "person", "type": "object", "optional": true}, {"name": "vec", "type": "float[]", "num_dim": 128, "optional": true}, - {"name": "product_id", "type": "string", "reference": "Products.product_id"} + {"name": "product_id", "type": "string", "reference": "Products.product_id", "optional": true} ], "default_sorting_field": "points", "symbols_to_index":["+"], From e16b4b7349201a0efa6d5b576a5c498d267c3c03 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 17 Jan 2023 09:49:17 +0530 Subject: [PATCH 04/27] Index `foo_sequence_id` field that stores sequence_id of referenced document. --- src/field.cpp | 7 +++++++ test/collection_join_test.cpp | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/src/field.cpp b/src/field.cpp index c184d8ab..a48c4637 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -697,6 +697,13 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::reference]) ); + if (!field_json[fields::reference].get().empty()) { + the_fields.emplace_back( + field(field_json[fields::name].get() + "_sequence_id", "string", false, + false, true) + ); + } + return Option(true); } diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 0065e8ae..81a311bb 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -92,5 +92,10 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { ASSERT_EQ(schema.at("customer_name").reference, ""); ASSERT_EQ(schema.at("product_id").reference, "Products.product_id"); + ASSERT_EQ(schema.count("product_id_sequence_id"), 1); + + auto field = schema.at("product_id_sequence_id"); + ASSERT_TRUE(field.index); + collectionManager.drop_collection("Customers"); } \ No newline at end of file From b41db06b1aba34e86015121bce3feed68ca6c9f3 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 17 Jan 2023 14:08:39 +0530 Subject: [PATCH 05/27] Abstract `foo_sequence_id` field from user. --- include/field.h | 7 +++++++ src/collection_manager.cpp | 4 ++++ src/field.cpp | 2 +- test/collection_join_test.cpp | 6 +++--- test/collection_manager_test.cpp | 10 ++++++++-- 5 files changed, 23 insertions(+), 6 deletions(-) diff --git a/include/field.h b/include/field.h index ca29fa43..0b6307a2 100644 --- a/include/field.h +++ b/include/field.h @@ -9,6 +9,7 @@ #include #include #include "json.hpp" +#include namespace field_types { // first field value indexed will determine the type @@ -278,11 +279,17 @@ struct field { const std::string & default_sorting_field, nlohmann::json& fields_json) { bool found_default_sorting_field = false; + const std::regex sequence_id_pattern(".*_sequence_id$"); // Check for duplicates in field names std::map> unique_fields; for(const field & field: fields) { + if (std::regex_match(field.name, sequence_id_pattern)) { + // Don't add foo_sequence_id field. + continue; + } + unique_fields[field.name].push_back(&field); if(field.name == "id") { diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 0fb4d5be..f1540932 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -80,6 +80,10 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection } fields.push_back(f); + + if (!f.reference.empty()) { + fields.emplace_back(field(f.name + "_sequence_id", "string", false, f.optional, true)); + } } std::string default_sorting_field = collection_meta[Collection::COLLECTION_DEFAULT_SORTING_FIELD_KEY].get(); diff --git a/src/field.cpp b/src/field.cpp index a48c4637..4275577a 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -700,7 +700,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso if (!field_json[fields::reference].get().empty()) { the_fields.emplace_back( field(field_json[fields::name].get() + "_sequence_id", "string", false, - false, true) + field_json[fields::optional], true) ); } diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 81a311bb..f57895d4 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -92,10 +92,10 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { ASSERT_EQ(schema.at("customer_name").reference, ""); ASSERT_EQ(schema.at("product_id").reference, "Products.product_id"); - ASSERT_EQ(schema.count("product_id_sequence_id"), 1); - auto field = schema.at("product_id_sequence_id"); - ASSERT_TRUE(field.index); + // Index a `foo_sequence_id` field for `foo` reference field. + ASSERT_EQ(schema.count("product_id_sequence_id"), 1); + ASSERT_TRUE(schema.at("product_id_sequence_id").index); collectionManager.drop_collection("Customers"); } \ No newline at end of file diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 882f92fd..dd4debd2 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -221,7 +221,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) { "locale":"", "name":"product_id", "nested":false, - "optional":false, + "optional":true, "sort":false, "type":"string", "reference":"Products.product_id" @@ -351,7 +351,11 @@ TEST_F(CollectionManagerTest, RestoreRecordsOnRestart) { std::string json_line; while (std::getline(infile, json_line)) { - collection1->add(json_line); + auto op = collection1->add(json_line); + if (!op.ok()) { + LOG(INFO) << op.error(); + } + ASSERT_TRUE(op.ok()); } infile.close(); @@ -434,6 +438,7 @@ TEST_F(CollectionManagerTest, RestoreRecordsOnRestart) { ASSERT_EQ(4, results["hits"].size()); tsl::htrie_map schema = collection1->get_schema(); + ASSERT_EQ(schema.count("product_id_sequence_id"), 1); // recreate collection manager to ensure that it restores the records from the disk backed store collectionManager.dispose(); @@ -472,6 +477,7 @@ TEST_F(CollectionManagerTest, RestoreRecordsOnRestart) { ASSERT_TRUE(restored_schema.at("person").nested); ASSERT_EQ(2, restored_schema.at("person").nested_array); ASSERT_EQ(128, restored_schema.at("vec").num_dim); + ASSERT_EQ(restored_schema.count("product_id_sequence_id"), 1); ASSERT_TRUE(collection1->get_enable_nested_fields()); From 11c9d98a4338b9a82085ed5536df0552163b0146 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 18 Jan 2023 17:47:52 +0530 Subject: [PATCH 06/27] Remove `foo_sequence_id` from collection summary. --- src/collection.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/collection.cpp b/src/collection.cpp index ab162605..da1f0e5a 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -159,8 +159,14 @@ nlohmann::json Collection::get_summary_json() const { } nlohmann::json fields_arr; + const std::regex sequence_id_pattern(".*_sequence_id$"); for(const field & coll_field: fields) { + if (std::regex_match(coll_field.name, sequence_id_pattern)) { + // Don't add foo_sequence_id field. + continue; + } + nlohmann::json field_json; field_json[fields::name] = coll_field.name; field_json[fields::type] = coll_field.type; From eef346b29f0ab62c556aca934ce07e9b926207b6 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 18 Jan 2023 18:12:25 +0530 Subject: [PATCH 07/27] Index document containing a reference field. --- src/index.cpp | 44 +++++++++ test/collection_join_test.cpp | 167 ++++++++++++++++++++++++++++++++++ 2 files changed, 211 insertions(+) diff --git a/src/index.cpp b/src/index.cpp index a71c1498..2db453de 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -23,6 +23,7 @@ #include #include #include "logger.h" +#include #define RETURN_CIRCUIT_BREAKER if((std::chrono::duration_cast( \ std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { \ @@ -430,6 +431,49 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 continue; } + if (!a_field.reference.empty()) { + // Add foo_sequence_id field in the document. + + std::vector tokens; + StringUtils::split(a_field.reference, tokens, "."); + + if (tokens.size() < 2) { + return Option<>(400, "Invalid reference `" + a_field.reference + "`."); + } + + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(tokens[0]); + if (collection == nullptr) { + return Option<>(400, "Referenced collection `" + tokens[0] + "` not found."); + } + + if (collection->get_schema().count(tokens[1]) == 0) { + return Option<>(400, "Referenced field `" + tokens[1] + "` not found in the collection `" + + tokens[0] + "`."); + } + + auto referenced_field_name = tokens[1]; + if (!collection->get_schema().at(referenced_field_name).index) { + return Option<>(400, "Referenced field `" + tokens[1] + "` in the collection `" + + tokens[0] + "` must be indexed."); + } + + std::vector> documents; + auto value = document[a_field.name].get(); + collection->get_filter_ids(referenced_field_name + ":=" + value, documents); + + if (documents[0].first != 1) { + auto match = " `" + referenced_field_name + "` = `" + value + "` "; + return Option<>(400, documents[0].first < 1 ? + "Referenced document having" + match + "not found in the collection `" + tokens[0] + "`." : + "Multiple documents having" + match + "found in the collection `" + tokens[0] + "`."); + } + + document[a_field.name + "_sequence_id"] = collection->get_seq_id_collection_prefix() + std::to_string(*(documents[0].second)); + + delete [] documents[0].second; + } + if(document.count(field_name) == 0) { return Option<>(400, "Field `" + field_name + "` has been declared in the schema, " "but is not found in the document."); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index f57895d4..e79f76d4 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -98,4 +98,171 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { ASSERT_TRUE(schema.at("product_id_sequence_id").index); collectionManager.drop_collection("Customers"); +} + +TEST_F(CollectionJoinTest, IndexReferenceField) { + auto products_schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string", "index": false, "optional": true}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + auto collection_create_op = collectionManager.create_collection(products_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + auto customers_schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "foo"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(customers_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + nlohmann::json customer_json = R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 143, + "product_id": "a" + })"_json; + + auto customer_collection = collection_create_op.get(); + auto add_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Invalid reference `foo`.", add_op.error()); + collectionManager.drop_collection("Customers"); + + customers_schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "products.product_id"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(customers_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + customer_collection = collection_create_op.get(); + add_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Referenced collection `products` not found.", add_op.error()); + collectionManager.drop_collection("Customers"); + + customers_schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "Products.id"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(customers_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + customer_collection = collection_create_op.get(); + add_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Referenced field `id` not found in the collection `Products`.", add_op.error()); + collectionManager.drop_collection("Customers"); + + customers_schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "Products.product_id"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(customers_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + customer_collection = collection_create_op.get(); + add_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Referenced field `product_id` in the collection `Products` must be indexed.", add_op.error()); + + collectionManager.drop_collection("Products"); + products_schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string"}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(products_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + add_op = customer_collection->add(customer_json.dump()); + ASSERT_EQ("Referenced document having `product_id` = `a` not found in the collection `Products`.", add_op.error()); + + std::vector products = { + R"({ + "product_id": "product_a", + "product_name": "shampoo", + "product_description": "Our new moisturizing shampoo is perfect for those with dry or damaged hair." + })"_json, + R"({ + "product_id": "product_a", + "product_name": "soap", + "product_description": "Introducing our all-natural, organic soap bar made with essential oils and botanical ingredients." + })"_json + }; + for (auto const &json: products){ + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + customer_json["product_id"] = "product_a"; + add_op = customer_collection->add(customer_json.dump()); + ASSERT_EQ("Multiple documents having `product_id` = `product_a` found in the collection `Products`.", add_op.error()); + + collectionManager.drop_collection("Products"); + products[1]["product_id"] = "product_b"; + products_schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string"}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(products_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: products){ + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + add_op = customer_collection->add(customer_json.dump()); + ASSERT_TRUE(add_op.ok()); + + auto result = customer_collection->search("*", {"customer_id"}, "", {}, {}, {0}).get(); + ASSERT_EQ(result["hits"][0]["document"].count("product_id"), 1); +// ASSERT_EQ(result["hits"][0]["document"].count("product_id_sequence_id"), 0); + + collectionManager.drop_collection("Customers"); + collectionManager.drop_collection("Products"); } \ No newline at end of file From 0acd2c06c371dc8012df5510903a70724785ada6 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 19 Jan 2023 11:25:43 +0530 Subject: [PATCH 08/27] Serialize sequence id. --- src/index.cpp | 3 ++- test/collection_join_test.cpp | 23 +++++++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 2db453de..d14db052 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -469,7 +469,8 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 "Multiple documents having" + match + "found in the collection `" + tokens[0] + "`."); } - document[a_field.name + "_sequence_id"] = collection->get_seq_id_collection_prefix() + std::to_string(*(documents[0].second)); + document[a_field.name + "_sequence_id"] = collection->get_seq_id_collection_prefix() + "_" + + StringUtils::serialize_uint32_t(*(documents[0].second)); delete [] documents[0].second; } diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index e79f76d4..9964c864 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -255,13 +255,28 @@ TEST_F(CollectionJoinTest, IndexReferenceField) { } ASSERT_TRUE(add_op.ok()); } - + collectionManager.drop_collection("Customers"); + customers_schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "Products.product_id"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(customers_schema_json); + ASSERT_TRUE(collection_create_op.ok()); add_op = customer_collection->add(customer_json.dump()); ASSERT_TRUE(add_op.ok()); + ASSERT_EQ(customer_collection->get("0").get().count("product_id_sequence_id"), 1); - auto result = customer_collection->search("*", {"customer_id"}, "", {}, {}, {0}).get(); - ASSERT_EQ(result["hits"][0]["document"].count("product_id"), 1); -// ASSERT_EQ(result["hits"][0]["document"].count("product_id_sequence_id"), 0); + nlohmann::json document; + auto get_op = customer_collection->get_document_from_store(customer_collection->get("0").get()["product_id_sequence_id"].get(), document); + ASSERT_TRUE(get_op.ok()); + ASSERT_EQ(document.count("product_id"), 1); + ASSERT_EQ(document["product_id"], "product_a"); collectionManager.drop_collection("Customers"); collectionManager.drop_collection("Products"); From c4730c60b38f7f5bbfae55a3e29400a4fc13b518 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 19 Jan 2023 11:27:52 +0530 Subject: [PATCH 09/27] Store `foo_sequence_id` in collection's meta-data. --- include/field.h | 7 ------- src/collection_manager.cpp | 4 ---- test/collection_manager_test.cpp | 11 +++++++++++ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/field.h b/include/field.h index 0b6307a2..ca29fa43 100644 --- a/include/field.h +++ b/include/field.h @@ -9,7 +9,6 @@ #include #include #include "json.hpp" -#include namespace field_types { // first field value indexed will determine the type @@ -279,17 +278,11 @@ struct field { const std::string & default_sorting_field, nlohmann::json& fields_json) { bool found_default_sorting_field = false; - const std::regex sequence_id_pattern(".*_sequence_id$"); // Check for duplicates in field names std::map> unique_fields; for(const field & field: fields) { - if (std::regex_match(field.name, sequence_id_pattern)) { - // Don't add foo_sequence_id field. - continue; - } - unique_fields[field.name].push_back(&field); if(field.name == "id") { diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f1540932..0fb4d5be 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -80,10 +80,6 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection } fields.push_back(f); - - if (!f.reference.empty()) { - fields.emplace_back(field(f.name + "_sequence_id", "string", false, f.optional, true)); - } } std::string default_sorting_field = collection_meta[Collection::COLLECTION_DEFAULT_SORTING_FIELD_KEY].get(); diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index dd4debd2..1bd74a27 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -225,6 +225,17 @@ TEST_F(CollectionManagerTest, CollectionCreation) { "sort":false, "type":"string", "reference":"Products.product_id" + }, + { + "facet":false, + "index":true, + "infix":false, + "locale":"", + "name":"product_id_sequence_id", + "nested":false, + "optional":true, + "sort":false, + "type":"string" } ], "id":0, From 3bc7275e2322a38118bad85cc19aa7b0a59f1c40 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 19 Jan 2023 11:39:27 +0530 Subject: [PATCH 10/27] Fix memory leak. --- src/index.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/index.cpp b/src/index.cpp index d14db052..c9f79b38 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -463,6 +463,7 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 collection->get_filter_ids(referenced_field_name + ":=" + value, documents); if (documents[0].first != 1) { + delete [] documents[0].second; auto match = " `" + referenced_field_name + "` = `" + value + "` "; return Option<>(400, documents[0].first < 1 ? "Referenced document having" + match + "not found in the collection `" + tokens[0] + "`." : From ebfbf4f48d62af40e551d8725fbd68b8f6f73921 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Sun, 22 Jan 2023 12:02:29 +0530 Subject: [PATCH 11/27] Filter by reference. --- include/collection.h | 2 + include/field.h | 5 +- src/collection.cpp | 16 +++++ src/field.cpp | 30 ++++++++- src/index.cpp | 54 +++++++++++++++- src/string_utils.cpp | 42 +++++++++++- test/collection_join_test.cpp | 117 +++++++++++++++++++++++++++++++++- test/string_utils_test.cpp | 4 ++ 8 files changed, 260 insertions(+), 10 deletions(-) diff --git a/include/collection.h b/include/collection.h index 720631c0..71b879b1 100644 --- a/include/collection.h +++ b/include/collection.h @@ -436,6 +436,8 @@ public: Option get_filter_ids(const std::string & filter_query, std::vector>& index_ids) const; + Option validate_reference_filter(const std::string& filter_query) const; + Option get(const std::string & id) const; Option remove(const std::string & id, bool remove_from_store = true); diff --git a/include/field.h b/include/field.h index ca29fa43..ee69bff5 100644 --- a/include/field.h +++ b/include/field.h @@ -77,7 +77,7 @@ struct field { static constexpr int VAL_UNKNOWN = 2; - std::string reference; // Reference to another collection. + std::string reference; // Foo.bar (reference to bar field in Foo collection). field() {} @@ -448,6 +448,9 @@ struct filter { // aggregated and then this flag is checked if negation on the aggregated result is required. bool apply_not_equals = false; + // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` + std::string referenced_collection_name; + static const std::string RANGE_OPERATOR() { return ".."; } diff --git a/src/collection.cpp b/src/collection.cpp index da1f0e5a..516a8bf3 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -2362,6 +2362,22 @@ Option Collection::get_filter_ids(const std::string & filter_query, return Option(true); } +Option Collection::validate_reference_filter(const std::string& filter_query) const { + std::shared_lock lock(mutex); + + const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_"; + filter_node_t* filter_tree_root = nullptr; + Option filter_op = filter::parse_filter_query(filter_query, search_schema, + store, doc_id_prefix, filter_tree_root); + + if(!filter_op.ok()) { + return filter_op; + } + + delete filter_tree_root; + return Option(true); +} + bool Collection::facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, const nlohmann::json &document, std::string &value) const { diff --git a/src/field.cpp b/src/field.cpp index 4275577a..f0bd4eaa 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -2,6 +2,7 @@ #include "field.h" #include "magic_enum.hpp" #include +#include Option filter::parse_geopoint_filter_value(std::string& raw_value, const std::string& format_err_msg, @@ -408,9 +409,32 @@ Option toParseTree(std::queue& postfix, filter_node_t*& root, filter_node = new filter_node_t(expression == "&&" ? AND : OR, operandA, operandB); } else { filter filter_exp; - Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); - if (!toFilter_op.ok()) { - return toFilter_op; + + // Expected value: $Collection(...) + bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')'); + if (is_referenced_filter) { + size_t parenthesis_index = expression.find('('); + + std::string collection_name = expression.substr(1, parenthesis_index - 1); + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(collection_name); + if (collection == nullptr) { + return Option(400, "Referenced collection `" + collection_name + "` not found."); + } + + filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)}; + filter_exp.referenced_collection_name = collection_name; + + auto op = collection->validate_reference_filter(filter_exp.field_name); + if (!op.ok()) { + return Option(400, "Failed to parse reference filter on `" + collection_name + + "` collection: " + op.error()); + } + } else { + Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); + if (!toFilter_op.ok()) { + return toFilter_op; + } } filter_node = new filter_node_t(filter_exp); diff --git a/src/index.cpp b/src/index.cpp index c9f79b38..e21dddcc 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -470,8 +470,7 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 "Multiple documents having" + match + "found in the collection `" + tokens[0] + "`."); } - document[a_field.name + "_sequence_id"] = collection->get_seq_id_collection_prefix() + "_" + - StringUtils::serialize_uint32_t(*(documents[0].second)); + document[a_field.name + "_sequence_id"] = StringUtils::serialize_uint32_t(*(documents[0].second)); delete [] documents[0].second; } @@ -1668,6 +1667,57 @@ void Index::do_filtering(uint32_t*& filter_ids, // auto begin = std::chrono::high_resolution_clock::now(); const filter a_filter = root->filter_exp; + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(a_filter.referenced_collection_name); + + std::vector> documents; + auto op = collection->get_filter_ids(a_filter.field_name, documents); + if (!op.ok()) { + return; + } + + if (documents[0].first > 0) { + const field* reference_field = nullptr; + for (auto const& f: collection->get_fields()) { + auto this_collection_name = cm.get_collection_with_id(collection_id)->get_name(); + if (!f.reference.empty() && + f.reference.find(this_collection_name) == 0 && + f.reference.find('.') == this_collection_name.size()) { + reference_field = &f; + break; + } + } + + if (reference_field == nullptr) { + return; + } + + std::vector result_ids; + for (size_t i = 0; i < documents[0].first; i++) { + uint32_t seq_id = *(documents[0].second + i); + + nlohmann::json document; + auto op = collection->get_document_from_store(seq_id, document); + if (!op.ok()) { + return; + } + + result_ids.push_back(StringUtils::deserialize_uint32_t(document[reference_field->name + "_sequence_id"].get())); + } + + filter_ids = new uint32[result_ids.size()]; + std::sort(result_ids.begin(), result_ids.end()); + std::copy(result_ids.begin(), result_ids.end(), filter_ids); + filter_ids_length = result_ids.size(); + } + + delete [] documents[0].second; + return; + } + if (a_filter.field_name == "id") { // we handle `ids` separately std::vector result_ids; diff --git a/src/string_utils.cpp b/src/string_utils.cpp index e7893004..a9409400 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -349,9 +349,40 @@ size_t StringUtils::get_num_chars(const std::string& s) { return j; } +Option parse_reference_filter(const std::string& filter_query, std::queue& tokens, size_t& index) { + auto error = Option(400, "Could not parse the reference filter."); + if (filter_query[index] != '$') { + return error; + } + + int start_index = index; + auto size = filter_query.size(); + while(++index < size && filter_query[index] != '(') {} + + if (index >= size) { + return error; + } + + int parenthesis_count = 1; + while (++index < size && parenthesis_count > 0) { + if (filter_query[index] == '(') { + parenthesis_count++; + } else if (filter_query[index] == ')') { + parenthesis_count--; + } + } + + if (parenthesis_count != 0) { + return error; + } + + tokens.push(filter_query.substr(start_index, index - start_index)); + return Option(true); +} + Option StringUtils::tokenize_filter_query(const std::string& filter_query, std::queue& tokens) { auto size = filter_query.size(); - for (auto i = 0; i < size;) { + for (size_t i = 0; i < size;) { auto c = filter_query[i]; if (c == ' ') { i++; @@ -377,6 +408,15 @@ Option StringUtils::tokenize_filter_query(const std::string& filter_query, tokens.push("||"); i += 2; } else { + // Reference filter would start with $ symbol. + if (c == '$') { + auto op = parse_reference_filter(filter_query, tokens, i); + if (!op.ok()) { + return op; + } + continue; + } + std::stringstream ss; bool inBacktick = false; bool preceding_colon = false; diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 9964c864..26a7d476 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -93,14 +93,14 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { ASSERT_EQ(schema.at("customer_name").reference, ""); ASSERT_EQ(schema.at("product_id").reference, "Products.product_id"); - // Index a `foo_sequence_id` field for `foo` reference field. + // Add a `foo_sequence_id` field in the schema for `foo` reference field. ASSERT_EQ(schema.count("product_id_sequence_id"), 1); ASSERT_TRUE(schema.at("product_id_sequence_id").index); collectionManager.drop_collection("Customers"); } -TEST_F(CollectionJoinTest, IndexReferenceField) { +TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { auto products_schema_json = R"({ "name": "Products", @@ -272,12 +272,123 @@ TEST_F(CollectionJoinTest, IndexReferenceField) { ASSERT_TRUE(add_op.ok()); ASSERT_EQ(customer_collection->get("0").get().count("product_id_sequence_id"), 1); + // Referenced document should be accessible from Customers collection. + auto sequence_id = collectionManager.get_collection("Products")->get_seq_id_collection_prefix() + "_" + + customer_collection->get("0").get()["product_id_sequence_id"].get(); nlohmann::json document; - auto get_op = customer_collection->get_document_from_store(customer_collection->get("0").get()["product_id_sequence_id"].get(), document); + auto get_op = customer_collection->get_document_from_store(sequence_id, document); ASSERT_TRUE(get_op.ok()); ASSERT_EQ(document.count("product_id"), 1); ASSERT_EQ(document["product_id"], "product_a"); + ASSERT_EQ(document["product_name"], "shampoo"); collectionManager.drop_collection("Customers"); collectionManager.drop_collection("Products"); +} + +TEST_F(CollectionJoinTest, FilterByReferenceField) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string"}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + std::vector documents = { + R"({ + "product_id": "product_a", + "product_name": "shampoo", + "product_description": "Our new moisturizing shampoo is perfect for those with dry or damaged hair." + })"_json, + R"({ + "product_id": "product_b", + "product_name": "soap", + "product_description": "Introducing our all-natural, organic soap bar made with essential oils and botanical ingredients." + })"_json + }; + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "Products.product_id"} + ] + })"_json; + documents = { + R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 143, + "product_id": "product_a" + })"_json, + R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 73.5, + "product_id": "product_b" + })"_json, + R"({ + "customer_id": "customer_b", + "customer_name": "Dan", + "product_price": 75, + "product_id": "product_a" + })"_json, + R"({ + "customer_id": "customer_b", + "customer_name": "Dan", + "product_price": 140, + "product_id": "product_b" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + auto coll = collectionManager.get_collection("Products"); + auto search_op = coll->search("s", {"product_name"}, "$foo:=customer_a", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Could not parse the reference filter."); + + search_op = coll->search("s", {"product_name"}, "$foo(:=customer_a", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Could not parse the reference filter."); + + search_op = coll->search("s", {"product_name"}, "$foo(:=customer_a)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Referenced collection `foo` not found."); + + search_op = coll->search("s", {"product_name"}, "$Customers(foo:=customer_a)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ(search_op.error(), "Failed to parse reference filter on `Customers` collection: Could not find a filter field named `foo` in the schema."); + + auto result = coll->search("s", {"product_name"}, "$Customers(customer_id:=customer_a && product_price:<100)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD).get(); + + ASSERT_EQ(1, result["found"].get()); + ASSERT_EQ(1, result["hits"].size()); + ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get()); } \ No newline at end of file diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index 1d2e0246..31acfafb 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -393,4 +393,8 @@ TEST(StringUtilsTest, TokenizeFilterQuery) { filter_query = "((age:<5||age:>10)&&location:(48.906,2.343,5mi))||tags:AT&T"; tokenList = {"(", "(", "age:<5", "||", "age:>10", ")", "&&", "location:(48.906,2.343,5mi)", ")", "||", "tags:AT&T"}; tokenizeTestHelper(filter_query, tokenList); + + filter_query = "((age: <5 || age: >10) && category:= [shoes]) && $Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; + tokenList = {"(", "(", "age: <5", "||", "age: >10", ")", "&&", "category:= [shoes]", ")", "&&", "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"}; + tokenizeTestHelper(filter_query, tokenList); } From 6c5662bc955e2901be1e03e46f9a9869ae0a0d3b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 24 Jan 2023 10:57:29 +0530 Subject: [PATCH 12/27] Optimize reference filtering. --- include/collection.h | 4 ++ include/index.h | 4 ++ src/collection.cpp | 35 +++++++++++++++ src/field.cpp | 2 +- src/index.cpp | 68 ++++++++++++----------------- test/collection_join_test.cpp | 74 ++++++++++++++++---------------- test/collection_manager_test.cpp | 16 ++++--- 7 files changed, 120 insertions(+), 83 deletions(-) diff --git a/include/collection.h b/include/collection.h index 71b879b1..6d25bcfa 100644 --- a/include/collection.h +++ b/include/collection.h @@ -436,6 +436,10 @@ public: Option get_filter_ids(const std::string & filter_query, std::vector>& index_ids) const; + Option get_reference_filter_ids(const std::string & filter_query, + const std::string & collection_name, + std::pair& reference_index_ids) const; + Option validate_reference_filter(const std::string& filter_query) const; Option get(const std::string & id) const; diff --git a/include/index.h b/include/index.h index 999993b6..a6b161cd 100644 --- a/include/index.h +++ b/include/index.h @@ -715,6 +715,10 @@ public: uint32_t& filter_ids_length, filter_node_t const* const& filter_tree_root) const; + void do_reference_filtering_with_lock(std::pair& reference_index_ids, + filter_node_t const* const& filter_tree_root, + const std::string& reference_field_name) const; + void refresh_schemas(const std::vector& new_fields, const std::vector& del_fields); // the following methods are not synchronized because their parent calls are synchronized or they are const/static diff --git a/src/collection.cpp b/src/collection.cpp index 516a8bf3..7716b5ac 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -2362,6 +2362,41 @@ Option Collection::get_filter_ids(const std::string & filter_query, return Option(true); } +Option Collection::get_reference_filter_ids(const std::string & filter_query, + const std::string & collection_name, + std::pair& reference_index_ids) const { + std::shared_lock lock(mutex); + + std::string reference_field_name; + for (auto const& field: fields) { + if (!field.reference.empty() && + field.reference.find(collection_name) == 0 && + field.reference.find('.') == collection_name.size()) { + reference_field_name = field.name; + break; + } + } + + if (reference_field_name.empty()) { + return Option(400, "Could not find any field in `" + name + "` referencing the collection `" + + collection_name + "`."); + } + + const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_"; + filter_node_t* filter_tree_root = nullptr; + Option filter_op = filter::parse_filter_query(filter_query, search_schema, + store, doc_id_prefix, filter_tree_root); + if(!filter_op.ok()) { + return filter_op; + } + + reference_field_name += "_sequence_id"; + index->do_reference_filtering_with_lock(reference_index_ids, filter_tree_root, reference_field_name); + + delete filter_tree_root; + return Option(true); +} + Option Collection::validate_reference_filter(const std::string& filter_query) const { std::shared_lock lock(mutex); diff --git a/src/field.cpp b/src/field.cpp index f0bd4eaa..36ab953e 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -723,7 +723,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso if (!field_json[fields::reference].get().empty()) { the_fields.emplace_back( - field(field_json[fields::name].get() + "_sequence_id", "string", false, + field(field_json[fields::name].get() + "_sequence_id", "int64", false, field_json[fields::optional], true) ); } diff --git a/src/index.cpp b/src/index.cpp index e21dddcc..64ad5a66 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -470,7 +470,7 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 "Multiple documents having" + match + "found in the collection `" + tokens[0] + "`."); } - document[a_field.name + "_sequence_id"] = StringUtils::serialize_uint32_t(*(documents[0].second)); + document[a_field.name + "_sequence_id"] = *(documents[0].second); delete [] documents[0].second; } @@ -1665,7 +1665,7 @@ void Index::do_filtering(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t const* const root) const { // auto begin = std::chrono::high_resolution_clock::now(); - const filter a_filter = root->filter_exp; +/**/ const filter a_filter = root->filter_exp; bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); if (is_referenced_filter) { @@ -1673,48 +1673,16 @@ void Index::do_filtering(uint32_t*& filter_ids, auto& cm = CollectionManager::get_instance(); auto collection = cm.get_collection(a_filter.referenced_collection_name); - std::vector> documents; - auto op = collection->get_filter_ids(a_filter.field_name, documents); + std::pair documents; + auto op = collection->get_reference_filter_ids(a_filter.field_name, + cm.get_collection_with_id(collection_id)->get_name(), + documents); if (!op.ok()) { return; } - if (documents[0].first > 0) { - const field* reference_field = nullptr; - for (auto const& f: collection->get_fields()) { - auto this_collection_name = cm.get_collection_with_id(collection_id)->get_name(); - if (!f.reference.empty() && - f.reference.find(this_collection_name) == 0 && - f.reference.find('.') == this_collection_name.size()) { - reference_field = &f; - break; - } - } - - if (reference_field == nullptr) { - return; - } - - std::vector result_ids; - for (size_t i = 0; i < documents[0].first; i++) { - uint32_t seq_id = *(documents[0].second + i); - - nlohmann::json document; - auto op = collection->get_document_from_store(seq_id, document); - if (!op.ok()) { - return; - } - - result_ids.push_back(StringUtils::deserialize_uint32_t(document[reference_field->name + "_sequence_id"].get())); - } - - filter_ids = new uint32[result_ids.size()]; - std::sort(result_ids.begin(), result_ids.end()); - std::copy(result_ids.begin(), result_ids.end(), filter_ids); - filter_ids_length = result_ids.size(); - } - - delete [] documents[0].second; + filter_ids_length = documents.first; + filter_ids = documents.second; return; } @@ -2099,6 +2067,26 @@ void Index::do_filtering_with_lock(uint32_t*& filter_ids, recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); } +void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, + filter_node_t const* const& filter_tree_root, + const std::string& reference_field_name) const { + std::shared_lock lock(mutex); + recursive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); + + std::vector vector; + vector.reserve(reference_index_ids.first); + + for (uint32_t i = 0; i < reference_index_ids.first; i++) { + auto filtered_doc_id = *(reference_index_ids.second + i); + + // Extract the sequence_id from the reference field. + vector.push_back(sort_index.at(reference_field_name)->at(filtered_doc_id)); + } + + std::sort(vector.begin(), vector.end()); + std::copy(vector.begin(), vector.end(), reference_index_ids.second); +} + void Index::run_search(search_args* search_params) { search(search_params->field_query_tokens, search_params->search_fields, diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 26a7d476..e7e5636f 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -101,18 +101,6 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { } TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { - auto products_schema_json = - R"({ - "name": "Products", - "fields": [ - {"name": "product_id", "type": "string", "index": false, "optional": true}, - {"name": "product_name", "type": "string"}, - {"name": "product_description", "type": "string"} - ] - })"_json; - auto collection_create_op = collectionManager.create_collection(products_schema_json); - ASSERT_TRUE(collection_create_op.ok()); - auto customers_schema_json = R"({ "name": "Customers", @@ -123,7 +111,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { {"name": "product_id", "type": "string", "reference": "foo"} ] })"_json; - collection_create_op = collectionManager.create_collection(customers_schema_json); + auto collection_create_op = collectionManager.create_collection(customers_schema_json); ASSERT_TRUE(collection_create_op.ok()); nlohmann::json customer_json = R"({ @@ -134,9 +122,9 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { })"_json; auto customer_collection = collection_create_op.get(); - auto add_op = customer_collection->add(customer_json.dump()); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Invalid reference `foo`.", add_op.error()); + auto add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_doc_op.ok()); + ASSERT_EQ("Invalid reference `foo`.", add_doc_op.error()); collectionManager.drop_collection("Customers"); customers_schema_json = @@ -153,9 +141,9 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { ASSERT_TRUE(collection_create_op.ok()); customer_collection = collection_create_op.get(); - add_op = customer_collection->add(customer_json.dump()); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Referenced collection `products` not found.", add_op.error()); + add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_doc_op.ok()); + ASSERT_EQ("Referenced collection `products` not found.", add_doc_op.error()); collectionManager.drop_collection("Customers"); customers_schema_json = @@ -170,11 +158,23 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { })"_json; collection_create_op = collectionManager.create_collection(customers_schema_json); ASSERT_TRUE(collection_create_op.ok()); - customer_collection = collection_create_op.get(); - add_op = customer_collection->add(customer_json.dump()); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Referenced field `id` not found in the collection `Products`.", add_op.error()); + + auto products_schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string", "index": false, "optional": true}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(products_schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_doc_op.ok()); + ASSERT_EQ("Referenced field `id` not found in the collection `Products`.", add_doc_op.error()); collectionManager.drop_collection("Customers"); customers_schema_json = @@ -191,9 +191,9 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { ASSERT_TRUE(collection_create_op.ok()); customer_collection = collection_create_op.get(); - add_op = customer_collection->add(customer_json.dump()); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Referenced field `product_id` in the collection `Products` must be indexed.", add_op.error()); + add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_doc_op.ok()); + ASSERT_EQ("Referenced field `product_id` in the collection `Products` must be indexed.", add_doc_op.error()); collectionManager.drop_collection("Products"); products_schema_json = @@ -208,8 +208,8 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { collection_create_op = collectionManager.create_collection(products_schema_json); ASSERT_TRUE(collection_create_op.ok()); - add_op = customer_collection->add(customer_json.dump()); - ASSERT_EQ("Referenced document having `product_id` = `a` not found in the collection `Products`.", add_op.error()); + add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_EQ("Referenced document having `product_id` = `a` not found in the collection `Products`.", add_doc_op.error()); std::vector products = { R"({ @@ -232,8 +232,8 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { } customer_json["product_id"] = "product_a"; - add_op = customer_collection->add(customer_json.dump()); - ASSERT_EQ("Multiple documents having `product_id` = `product_a` found in the collection `Products`.", add_op.error()); + add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_EQ("Multiple documents having `product_id` = `product_a` found in the collection `Products`.", add_doc_op.error()); collectionManager.drop_collection("Products"); products[1]["product_id"] = "product_b"; @@ -268,15 +268,17 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { })"_json; collection_create_op = collectionManager.create_collection(customers_schema_json); ASSERT_TRUE(collection_create_op.ok()); - add_op = customer_collection->add(customer_json.dump()); - ASSERT_TRUE(add_op.ok()); + + customer_collection = collection_create_op.get(); + add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_TRUE(add_doc_op.ok()); ASSERT_EQ(customer_collection->get("0").get().count("product_id_sequence_id"), 1); - // Referenced document should be accessible from Customers collection. - auto sequence_id = collectionManager.get_collection("Products")->get_seq_id_collection_prefix() + "_" + - customer_collection->get("0").get()["product_id_sequence_id"].get(); nlohmann::json document; - auto get_op = customer_collection->get_document_from_store(sequence_id, document); + // Referenced document's sequence_id must be valid. + auto get_op = collectionManager.get_collection("Products")->get_document_from_store( + customer_collection->get("0").get()["product_id_sequence_id"].get(), + document); ASSERT_TRUE(get_op.ok()); ASSERT_EQ(document.count("product_id"), 1); ASSERT_EQ(document["product_id"], "product_a"); diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 1bd74a27..b2410558 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -74,9 +74,11 @@ TEST_F(CollectionManagerTest, CollectionCreation) { ASSERT_EQ(0, collection1->get_collection_id()); ASSERT_EQ(0, collection1->get_next_seq_id()); ASSERT_EQ(facet_fields_expected, collection1->get_facet_fields()); - ASSERT_EQ(2, collection1->get_sort_fields().size()); + // product_id_sequence_id is also included + ASSERT_EQ(3, collection1->get_sort_fields().size()); ASSERT_EQ("location", collection1->get_sort_fields()[0].name); - ASSERT_EQ("points", collection1->get_sort_fields()[1].name); + ASSERT_EQ("product_id_sequence_id", collection1->get_sort_fields()[1].name); + ASSERT_EQ("points", collection1->get_sort_fields()[2].name); ASSERT_EQ(schema.size(), collection1->get_schema().size()); ASSERT_EQ("points", collection1->get_default_sorting_field()); ASSERT_EQ(false, schema.at("not_stored").index); @@ -234,8 +236,8 @@ TEST_F(CollectionManagerTest, CollectionCreation) { "name":"product_id_sequence_id", "nested":false, "optional":true, - "sort":false, - "type":"string" + "sort":true, + "type":"int64" } ], "id":0, @@ -473,9 +475,11 @@ TEST_F(CollectionManagerTest, RestoreRecordsOnRestart) { ASSERT_EQ(0, collection1->get_collection_id()); ASSERT_EQ(18, collection1->get_next_seq_id()); ASSERT_EQ(facet_fields_expected, collection1->get_facet_fields()); - ASSERT_EQ(2, collection1->get_sort_fields().size()); + // product_id_sequence_id is also included + ASSERT_EQ(3, collection1->get_sort_fields().size()); ASSERT_EQ("location", collection1->get_sort_fields()[0].name); - ASSERT_EQ("points", collection1->get_sort_fields()[1].name); + ASSERT_EQ("product_id_sequence_id", collection1->get_sort_fields()[1].name); + ASSERT_EQ("points", collection1->get_sort_fields()[2].name); ASSERT_EQ(schema.size(), collection1->get_schema().size()); ASSERT_EQ("points", collection1->get_default_sorting_field()); From 5c5f43195c759f9b4255ea7e4f6f7b461fd1d01e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 27 Jan 2023 12:57:13 +0530 Subject: [PATCH 13/27] Add `Index::rearranging_recursive_filter`. --- include/field.h | 2 + include/index.h | 20 ++-- src/index.cpp | 156 ++++++++++++++++++++++------- test/collection_filtering_test.cpp | 56 +++++++++++ 4 files changed, 188 insertions(+), 46 deletions(-) diff --git a/include/field.h b/include/field.h index ee69bff5..34d0c6e7 100644 --- a/include/field.h +++ b/include/field.h @@ -536,6 +536,7 @@ struct filter_node_t { bool isOperator; filter_node_t* left; filter_node_t* right; + std::pair match_index_ids; filter_node_t(filter filter_exp) : filter_exp(std::move(filter_exp)), @@ -552,6 +553,7 @@ struct filter_node_t { right(right) {} ~filter_node_t() { + delete[] match_index_ids.second; delete left; delete right; } diff --git a/include/index.h b/include/index.h index a6b161cd..f8690e99 100644 --- a/include/index.h +++ b/include/index.h @@ -99,7 +99,7 @@ struct search_args { std::vector field_query_tokens; std::vector search_fields; const text_match_type_t match_type; - const filter_node_t* filter_tree_root; + filter_node_t* filter_tree_root; std::vector& facets; std::vector>& included_ids; std::vector excluded_ids; @@ -484,14 +484,16 @@ private: uint32_t*& ids, size_t& ids_len) const; - void do_filtering(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t const* const root) const; + void do_filtering(filter_node_t* const root) const; + + void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const; void recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t const* const root, - const bool enable_short_circuit) const; + filter_node_t* const root, + const bool enable_short_circuit = false) const; + + void get_filter_matches(filter_node_t* const root, std::vector>& vec) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; @@ -653,7 +655,7 @@ public: void search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t const* const& filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, @@ -713,10 +715,10 @@ public: void do_filtering_with_lock( uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t const* const& filter_tree_root) const; + filter_node_t* filter_tree_root) const; void do_reference_filtering_with_lock(std::pair& reference_index_ids, - filter_node_t const* const& filter_tree_root, + filter_node_t* filter_tree_root, const std::string& reference_field_name) const; void refresh_schemas(const std::vector& new_fields, const std::vector& del_fields); diff --git a/src/index.cpp b/src/index.cpp index 64ad5a66..5e722fdd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1661,11 +1661,9 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, ids = out; } -void Index::do_filtering(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t const* const root) const { +void Index::do_filtering(filter_node_t* const root) const { // auto begin = std::chrono::high_resolution_clock::now(); -/**/ const filter a_filter = root->filter_exp; + const filter a_filter = root->filter_exp; bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); if (is_referenced_filter) { @@ -1673,16 +1671,12 @@ void Index::do_filtering(uint32_t*& filter_ids, auto& cm = CollectionManager::get_instance(); auto collection = cm.get_collection(a_filter.referenced_collection_name); - std::pair documents; auto op = collection->get_reference_filter_ids(a_filter.field_name, cm.get_collection_with_id(collection_id)->get_name(), - documents); + root->match_index_ids); if (!op.ok()) { return; } - - filter_ids_length = documents.first; - filter_ids = documents.second; return; } @@ -1695,17 +1689,9 @@ void Index::do_filtering(uint32_t*& filter_ids, std::sort(result_ids.begin(), result_ids.end()); - if (filter_ids_length == 0) { - filter_ids = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), filter_ids); - filter_ids_length = result_ids.size(); - } else { - uint32_t* filtered_results = nullptr; - filter_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, &result_ids[0], - result_ids.size(), &filtered_results); - delete[] filter_ids; - filter_ids = filtered_results; - } + root->match_index_ids.second = new uint32[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), root->match_index_ids.second); + root->match_index_ids.first = result_ids.size(); return; } @@ -2005,8 +1991,8 @@ void Index::do_filtering(uint32_t*& filter_ids, result_ids_len = to_include_ids_len; } - filter_ids = result_ids; - filter_ids_length = result_ids_len; + root->match_index_ids.first = result_ids_len; + root->match_index_ids.second = result_ids; /*long long int timeMillis = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() @@ -2015,38 +2001,131 @@ void Index::do_filtering(uint32_t*& filter_ids, LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ } -void Index::recursive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - const filter_node_t* root, - const bool enable_short_circuit) const { +void Index::get_filter_matches(filter_node_t* const root, std::vector>& vec) const { if (root == nullptr) { return; } + if (root->isOperator && root->filter_operator == OR) { + uint32_t* l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + } + + uint32_t* r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + recursive_filter(r_filter_ids, r_filter_ids_length, root->right); + } + + root->match_index_ids.first = ArrayUtils::or_scalar( + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &(root->match_index_ids.second)); + + delete[] l_filter_ids; + delete[] r_filter_ids; + + vec.emplace_back(root->match_index_ids.first, root); + } else if (root->left == nullptr && root->right == nullptr) { + do_filtering(root); + vec.emplace_back(root->match_index_ids.first, root); + } else { + get_filter_matches(root->left, vec); + get_filter_matches(root->right, vec); + } +} + +void evaluate_filter_tree(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + bool is_rearranged, + std::vector>& vec, + size_t& index) { + if (root == nullptr) { + return; + } + + if (root->isOperator) { + if (root->filter_operator == AND) { + + uint32_t* l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + evaluate_filter_tree(l_filter_ids, l_filter_ids_length, root->left, is_rearranged, vec, index); + } + + uint32_t* r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + evaluate_filter_tree(r_filter_ids, r_filter_ids_length, root->right, is_rearranged, vec, index); + } + + root->match_index_ids.first = ArrayUtils::and_scalar( + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &(root->match_index_ids.second)); + + filter_ids_length = root->match_index_ids.first; + filter_ids = root->match_index_ids.second; + } else { + filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first; + filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second; + index++; + } + } else if (root->left == nullptr && root->right == nullptr) { + filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first; + filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second; + index++; + } else { + // malformed + } +} + +void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const { + std::vector> vec; + get_filter_matches(root, vec); + + bool should_rearrange = vec.size() > 2; + if (should_rearrange) { + std::sort(vec.begin(), vec.end(), + [](const std::pair& lhs, const std::pair& rhs) { + return lhs.first < rhs.first; + }); + } + + size_t index = 0; + evaluate_filter_tree(filter_ids, filter_ids_length, root, should_rearrange, vec, index); +} + +void Index::recursive_filter(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const bool enable_short_circuit) const { + if (root == nullptr) { + return; + } uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { recursive_filter(l_filter_ids, l_filter_ids_length, root->left, enable_short_circuit); } - uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { recursive_filter(r_filter_ids, r_filter_ids_length, root->right, enable_short_circuit); } - if (root->isOperator) { uint32_t* filtered_results = nullptr; if (root->filter_operator == AND) { filter_ids_length = ArrayUtils::and_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &filtered_results); + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &filtered_results); } else { filter_ids_length = ArrayUtils::or_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &filtered_results); + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &filtered_results); } delete[] l_filter_ids; @@ -2054,7 +2133,10 @@ void Index::recursive_filter(uint32_t*& filter_ids, filter_ids = filtered_results; } else if (root->left == nullptr && root->right == nullptr) { - do_filtering(filter_ids, filter_ids_length, root); + do_filtering(root); + filter_ids_length = root->match_index_ids.first; + filter_ids = root->match_index_ids.second; + root->match_index_ids.second = nullptr; } else { // malformed } @@ -2062,13 +2144,13 @@ void Index::recursive_filter(uint32_t*& filter_ids, void Index::do_filtering_with_lock(uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t const* const& filter_tree_root) const { + filter_node_t* filter_tree_root) const { std::shared_lock lock(mutex); recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); } void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, - filter_node_t const* const& filter_tree_root, + filter_node_t* filter_tree_root, const std::string& reference_field_name) const { std::shared_lock lock(mutex); recursive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); @@ -2077,7 +2159,7 @@ void Index::do_reference_filtering_with_lock(std::pair& ref vector.reserve(reference_index_ids.first); for (uint32_t i = 0; i < reference_index_ids.first; i++) { - auto filtered_doc_id = *(reference_index_ids.second + i); + auto filtered_doc_id = reference_index_ids.second[i]; // Extract the sequence_id from the reference field. vector.push_back(sort_index.at(reference_field_name)->at(filtered_doc_id)); @@ -2550,7 +2632,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name void Index::search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t const* const& filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 39194688..d253bf4e 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -2536,3 +2536,59 @@ TEST_F(CollectionFilteringTest, FilteringAfterUpsertOnArrayWithSymbolsToIndex) { collectionManager.drop_collection("coll1"); } + +TEST_F(CollectionFilteringTest, ComplexFilterQuery) { + nlohmann::json schema_json = + R"({ + "name": "ComplexFilterQueryCollection", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int32"}, + {"name": "years", "type": "int32[]"}, + {"name": "rating", "type": "float"} + ] + })"_json; + + auto op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(op.ok()); + auto coll = op.get(); + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::string json_line; + while (std::getline(infile, json_line)) { + auto add_op = coll->add(json_line); + ASSERT_TRUE(add_op.ok()); + } + infile.close(); + + std::vector sort_fields_desc = {sort_by("rating", "DESC")}; + nlohmann::json results = coll->search("Jeremy", {"name"}, "(rating:>=0 && years:>2000) && age:>50", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["hits"].size()); + + results = coll->search("Jeremy", {"name"}, "(age:>50 || rating:>5) && years:<2000", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(2, results["hits"].size()); + + std::vector ids = {"4", "3"}; + for (size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results = coll->search("Jeremy", {"name"}, "(age:<50 && rating:10) || (years:>2000 && rating:<5)", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["hits"].size()); + + ids = {"0"}; + for (size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + collectionManager.drop_collection("ComplexFilterQueryCollection"); +} \ No newline at end of file From 85b8c836164261c92a539f7c705abd574b36dad1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 27 Jan 2023 17:46:33 +0530 Subject: [PATCH 14/27] Refactor `Index::rearranging_recursive_filter`. --- src/index.cpp | 76 +++++++++++++++++++++++---------------------------- 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 5e722fdd..202e35d5 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2006,6 +2006,9 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectorleft, vec); + get_filter_matches(root->right, vec); + if (root->isOperator && root->filter_operator == OR) { uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; @@ -2031,53 +2034,42 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectormatch_index_ids.first, root); } else { - get_filter_matches(root->left, vec); - get_filter_matches(root->right, vec); + // malformed } } -void evaluate_filter_tree(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - bool is_rearranged, - std::vector>& vec, - size_t& index) { +void evaluate_rearranged_filter_tree(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + std::vector>& vec, + size_t& index) { if (root == nullptr) { return; } - if (root->isOperator) { - if (root->filter_operator == AND) { + if (root->isOperator && root->filter_operator == AND) { + uint32_t* l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + evaluate_rearranged_filter_tree(l_filter_ids, l_filter_ids_length, root->left, vec, index); + } - uint32_t* l_filter_ids = nullptr; - uint32_t l_filter_ids_length = 0; - if (root->left != nullptr) { - evaluate_filter_tree(l_filter_ids, l_filter_ids_length, root->left, is_rearranged, vec, index); - } + uint32_t* r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + evaluate_rearranged_filter_tree(r_filter_ids, r_filter_ids_length, root->right, vec, index); + } - uint32_t* r_filter_ids = nullptr; - uint32_t r_filter_ids_length = 0; - if (root->right != nullptr) { - evaluate_filter_tree(r_filter_ids, r_filter_ids_length, root->right, is_rearranged, vec, index); - } - - root->match_index_ids.first = ArrayUtils::and_scalar( + root->match_index_ids.first = ArrayUtils::and_scalar( l_filter_ids, l_filter_ids_length, r_filter_ids, r_filter_ids_length, &(root->match_index_ids.second)); - filter_ids_length = root->match_index_ids.first; - filter_ids = root->match_index_ids.second; - } else { - filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first; - filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second; - index++; - } - } else if (root->left == nullptr && root->right == nullptr) { - filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first; - filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second; - index++; + filter_ids_length = root->match_index_ids.first; + filter_ids = root->match_index_ids.second; } else { - // malformed + filter_ids_length = vec[index].first; + filter_ids = vec[index].second->match_index_ids.second; + index++; } } @@ -2085,16 +2077,16 @@ void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter std::vector> vec; get_filter_matches(root, vec); - bool should_rearrange = vec.size() > 2; - if (should_rearrange) { - std::sort(vec.begin(), vec.end(), - [](const std::pair& lhs, const std::pair& rhs) { - return lhs.first < rhs.first; - }); - } + std::sort(vec.begin(), vec.end(), + [](const std::pair& lhs, const std::pair& rhs) { + return lhs.first < rhs.first; + }); size_t index = 0; - evaluate_filter_tree(filter_ids, filter_ids_length, root, should_rearrange, vec, index); + evaluate_rearranged_filter_tree(filter_ids, filter_ids_length, root, vec, index); + + // To disable deletion of filter_ids when filter tree is destructed. + root->match_index_ids.second = nullptr; } void Index::recursive_filter(uint32_t*& filter_ids, From 6c19c95af6d99cd86a48a6cbb268ecb1475cd95e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 27 Jan 2023 19:58:06 +0530 Subject: [PATCH 15/27] Add `Index::adaptive_filter`. --- include/field.h | 10 ++++++++- include/index.h | 5 +++++ src/field.cpp | 18 ++++++++++++++-- src/index.cpp | 33 ++++++++++++++++++++++++------ test/collection_filtering_test.cpp | 12 +++++++++++ 5 files changed, 69 insertions(+), 9 deletions(-) diff --git a/include/field.h b/include/field.h index 34d0c6e7..e189648d 100644 --- a/include/field.h +++ b/include/field.h @@ -530,13 +530,20 @@ struct filter { filter_node_t*& root); }; +struct filter_tree_metrics { + int filter_exp_count; + int and_operator_count; + int or_operator_count; +}; + struct filter_node_t { filter filter_exp; FILTER_OPERATOR filter_operator; bool isOperator; filter_node_t* left; filter_node_t* right; - std::pair match_index_ids; + std::pair match_index_ids = {0, nullptr}; + filter_tree_metrics* metrics = nullptr; filter_node_t(filter filter_exp) : filter_exp(std::move(filter_exp)), @@ -553,6 +560,7 @@ struct filter_node_t { right(right) {} ~filter_node_t() { + delete metrics; delete[] match_index_ids.second; delete left; delete right; diff --git a/include/index.h b/include/index.h index f8690e99..db16a24c 100644 --- a/include/index.h +++ b/include/index.h @@ -493,6 +493,11 @@ private: filter_node_t* const root, const bool enable_short_circuit = false) const; + void adaptive_filter(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const filter_tree_root, + const bool enable_short_circuit = false) const; + void get_filter_matches(filter_node_t* const root, std::vector>& vec) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, diff --git a/src/field.cpp b/src/field.cpp index 36ab953e..e5aa527c 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -383,7 +383,9 @@ Option toFilter(const std::string expression, Option toParseTree(std::queue& postfix, filter_node_t*& root, const tsl::htrie_map& search_schema, const Store* store, - const std::string& doc_id_prefix) { + const std::string& doc_id_prefix, + int& and_operator_count, + int& or_operator_count) { std::stack nodeStack; while (!postfix.empty()) { @@ -406,6 +408,7 @@ Option toParseTree(std::queue& postfix, filter_node_t*& root, auto operandA = nodeStack.top(); nodeStack.pop(); + expression == "&&" ? and_operator_count++ : or_operator_count++; filter_node = new filter_node_t(expression == "&&" ? AND : OR, operandA, operandB); } else { filter filter_exp; @@ -478,11 +481,22 @@ Option filter::parse_filter_query(const std::string& filter_query, return toPostfix_op; } - Option toParseTree_op = toParseTree(postfix, root, search_schema, store, doc_id_prefix); + int postfix_size = (int) postfix.size(), and_operator_count = 0, or_operator_count = 0; + Option toParseTree_op = toParseTree(postfix, + root, + search_schema, + store, + doc_id_prefix, + and_operator_count, + or_operator_count); if (!toParseTree_op.ok()) { return toParseTree_op; } + root->metrics = new filter_tree_metrics{static_cast(postfix_size - (and_operator_count + or_operator_count)), + and_operator_count, + or_operator_count}; + return Option(true); } diff --git a/src/index.cpp b/src/index.cpp index 202e35d5..c4095454 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2013,13 +2013,13 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectorleft != nullptr) { - recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left); } uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - recursive_filter(r_filter_ids, r_filter_ids_length, root->right); + rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right); } root->match_index_ids.first = ArrayUtils::or_scalar( @@ -2102,12 +2102,14 @@ void Index::recursive_filter(uint32_t*& filter_ids, recursive_filter(l_filter_ids, l_filter_ids_length, root->left, enable_short_circuit); } + uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { recursive_filter(r_filter_ids, r_filter_ids_length, root->right, enable_short_circuit); } + if (root->isOperator) { uint32_t* filtered_results = nullptr; if (root->filter_operator == AND) { @@ -2134,18 +2136,37 @@ void Index::recursive_filter(uint32_t*& filter_ids, } } +void Index::adaptive_filter(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const filter_tree_root, + const bool enable_short_circuit) const { + if (filter_tree_root == nullptr) { + return; + } + + if (filter_tree_root->metrics != nullptr && + (*filter_tree_root->metrics).filter_exp_count > 2 && + (*filter_tree_root->metrics).and_operator_count > 0 && + // If there are more || in the filter tree than &&, we'll not gain much by rearranging the filter tree. + ((float) (*filter_tree_root->metrics).or_operator_count / (float) (*filter_tree_root->metrics).and_operator_count < 0.5)) { + rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root); + } else { + recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); + } +} + void Index::do_filtering_with_lock(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* filter_tree_root) const { std::shared_lock lock(mutex); - recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); + adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, false); } void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, const std::string& reference_field_name) const { std::shared_lock lock(mutex); - recursive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); + adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); std::vector vector; vector.reserve(reference_index_ids.first); @@ -2653,7 +2674,7 @@ void Index::search(std::vector& field_query_tokens, const std::v std::shared_lock lock(mutex); - recursive_filter(filter_ids, filter_ids_length, filter_tree_root, true); + adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, true); if (filter_tree_root != nullptr && filter_ids_length == 0) { delete [] filter_ids; @@ -4749,7 +4770,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - recursive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root, true); + adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root, true); } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index d253bf4e..2039c5e7 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -2590,5 +2590,17 @@ TEST_F(CollectionFilteringTest, ComplexFilterQuery) { ASSERT_STREQ(id.c_str(), result_id.c_str()); } +// results = coll->search("Jeremy", {"name"}, "years:>2000 && ((age:<30 && rating:>5) || (age:>50 && rating:<5))", +// {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); +// ASSERT_EQ(1, results["hits"].size()); +// +// ids = {"2"}; +// for (size_t i = 0; i < results["hits"].size(); i++) { +// nlohmann::json result = results["hits"].at(i); +// std::string result_id = result["document"]["id"]; +// std::string id = ids.at(i); +// ASSERT_STREQ(id.c_str(), result_id.c_str()); +// } + collectionManager.drop_collection("ComplexFilterQueryCollection"); } \ No newline at end of file From 0f2758fb97a80cc4b816b03e15c703b89e4c2287 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 30 Jan 2023 10:31:29 +0530 Subject: [PATCH 16/27] Fix `Index::get_filter_matches`. --- src/index.cpp | 6 ++---- test/collection_filtering_test.cpp | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index c4095454..9744eade 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2006,9 +2006,6 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectorleft, vec); - get_filter_matches(root->right, vec); - if (root->isOperator && root->filter_operator == OR) { uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; @@ -2034,7 +2031,8 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectormatch_index_ids.first, root); } else { - // malformed + get_filter_matches(root->left, vec); + get_filter_matches(root->right, vec); } } diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 2039c5e7..39a4c76a 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -2590,17 +2590,17 @@ TEST_F(CollectionFilteringTest, ComplexFilterQuery) { ASSERT_STREQ(id.c_str(), result_id.c_str()); } -// results = coll->search("Jeremy", {"name"}, "years:>2000 && ((age:<30 && rating:>5) || (age:>50 && rating:<5))", -// {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); -// ASSERT_EQ(1, results["hits"].size()); -// -// ids = {"2"}; -// for (size_t i = 0; i < results["hits"].size(); i++) { -// nlohmann::json result = results["hits"].at(i); -// std::string result_id = result["document"]["id"]; -// std::string id = ids.at(i); -// ASSERT_STREQ(id.c_str(), result_id.c_str()); -// } + results = coll->search("Jeremy", {"name"}, "years:>2000 && ((age:<30 && rating:>5) || (age:>50 && rating:<5))", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["hits"].size()); + + ids = {"2"}; + for (size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } collectionManager.drop_collection("ComplexFilterQueryCollection"); } \ No newline at end of file From 8abbced1bd8fef0f72416b11da38bbb8b1386512 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 30 Jan 2023 10:47:04 +0530 Subject: [PATCH 17/27] Refactor filtering logic. --- src/index.cpp | 96 +++++++++++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 9744eade..057c66b3 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2006,34 +2006,38 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectorisOperator && root->filter_operator == OR) { - uint32_t* l_filter_ids = nullptr; - uint32_t l_filter_ids_length = 0; - if (root->left != nullptr) { - rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + if (root->isOperator) { + if (root->filter_operator == AND) { + get_filter_matches(root->left, vec); + get_filter_matches(root->right, vec); + } else { + uint32_t *l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + } + + uint32_t *r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right); + } + + root->match_index_ids.first = ArrayUtils::or_scalar( + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &(root->match_index_ids.second)); + + delete[] l_filter_ids; + delete[] r_filter_ids; + + vec.emplace_back(root->match_index_ids.first, root); } - uint32_t* r_filter_ids = nullptr; - uint32_t r_filter_ids_length = 0; - if (root->right != nullptr) { - rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right); - } - - root->match_index_ids.first = ArrayUtils::or_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &(root->match_index_ids.second)); - - delete[] l_filter_ids; - delete[] r_filter_ids; - - vec.emplace_back(root->match_index_ids.first, root); - } else if (root->left == nullptr && root->right == nullptr) { - do_filtering(root); - vec.emplace_back(root->match_index_ids.first, root); - } else { - get_filter_matches(root->left, vec); - get_filter_matches(root->right, vec); + return; } + + do_filtering(root); + vec.emplace_back(root->match_index_ids.first, root); } void evaluate_rearranged_filter_tree(uint32_t*& filter_ids, @@ -2094,21 +2098,22 @@ void Index::recursive_filter(uint32_t*& filter_ids, if (root == nullptr) { return; } - uint32_t* l_filter_ids = nullptr; - uint32_t l_filter_ids_length = 0; - if (root->left != nullptr) { - recursive_filter(l_filter_ids, l_filter_ids_length, root->left, - enable_short_circuit); - } - - uint32_t* r_filter_ids = nullptr; - uint32_t r_filter_ids_length = 0; - if (root->right != nullptr) { - recursive_filter(r_filter_ids, r_filter_ids_length, root->right, - enable_short_circuit); - } if (root->isOperator) { + uint32_t* l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + recursive_filter(l_filter_ids, l_filter_ids_length, root->left, + enable_short_circuit); + } + + uint32_t* r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + recursive_filter(r_filter_ids, r_filter_ids_length, root->right, + enable_short_circuit); + } + uint32_t* filtered_results = nullptr; if (root->filter_operator == AND) { filter_ids_length = ArrayUtils::and_scalar( @@ -2124,14 +2129,15 @@ void Index::recursive_filter(uint32_t*& filter_ids, delete[] r_filter_ids; filter_ids = filtered_results; - } else if (root->left == nullptr && root->right == nullptr) { - do_filtering(root); - filter_ids_length = root->match_index_ids.first; - filter_ids = root->match_index_ids.second; - root->match_index_ids.second = nullptr; - } else { - // malformed + return; } + + do_filtering(root); + filter_ids_length = root->match_index_ids.first; + filter_ids = root->match_index_ids.second; + + // Prevents double deletion. We'll be deleting this array upstream and when the filter tree is destructed. + root->match_index_ids.second = nullptr; } void Index::adaptive_filter(uint32_t*& filter_ids, From 34f039e5844cfa6bc545d7478fa745bf4841e9fe Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 2 Feb 2023 11:23:09 +0530 Subject: [PATCH 18/27] Add `reference_fields` map in `Collection`. --- include/collection.h | 19 ++++++++- include/index.h | 2 +- src/collection.cpp | 74 +++++++++++++++++++++++++++++++--- src/field.cpp | 13 +++++- src/index.cpp | 50 ++--------------------- test/collection_join_test.cpp | 75 ++++++++++++++++++++--------------- 6 files changed, 143 insertions(+), 90 deletions(-) diff --git a/include/collection.h b/include/collection.h index 6d25bcfa..72ef1280 100644 --- a/include/collection.h +++ b/include/collection.h @@ -39,6 +39,13 @@ struct highlight_field_t { } }; +struct reference_pair { + std::string collection; + std::string field; + + reference_pair(std::string collection, std::string field) : collection(std::move(collection)), field(std::move(field)) {} +}; + class Collection { private: @@ -119,10 +126,14 @@ private: std::vector token_separators; - Index* index; - SynonymIndex* synonym_index; + // "field name" -> reference_pair + spp::sparse_hash_map reference_fields; + + // Keep index as the last field since it is initialized in the constructor via init_index(). Add a new field before it. + Index* index; + // methods std::string get_doc_id_key(const std::string & doc_id) const; @@ -282,6 +293,8 @@ public: static constexpr const char* COLLECTION_SYMBOLS_TO_INDEX = "symbols_to_index"; static constexpr const char* COLLECTION_SEPARATORS = "token_separators"; + static constexpr const char* REFERENCE_HELPER_FIELD_SUFFIX = "_sequence_id"; + // methods Collection() = delete; @@ -488,6 +501,8 @@ public: SynonymIndex* get_synonym_index(); + spp::sparse_hash_map get_reference_fields(); + // highlight ops static void highlight_text(const string& highlight_start_tag, const string& highlight_end_tag, diff --git a/include/index.h b/include/index.h index db16a24c..e142e6ae 100644 --- a/include/index.h +++ b/include/index.h @@ -724,7 +724,7 @@ public: void do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, - const std::string& reference_field_name) const; + const std::string& reference_helper_field_name) const; void refresh_schemas(const std::vector& new_fields, const std::vector& del_fields); diff --git a/src/collection.cpp b/src/collection.cpp index 7716b5ac..2e2b8b36 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -100,6 +100,56 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: // for UPSERT, EMPLACE or CREATE, if a document does not have an ID, we will treat it as a new doc uint32_t seq_id = get_next_seq_id(); document["id"] = std::to_string(seq_id); + + // Add reference helper fields in the document. + for (auto const& pair: reference_fields) { + auto field_name = pair.first; + auto optional = get_schema().at(field_name).optional; + if (!optional && document.count(field_name) != 1) { + return Option(400, "Missing the required reference field `" + field_name + + "` in the document."); + } else if (document.count(field_name) != 1) { + continue; + } + + auto reference_pair = pair.second; + auto reference_collection_name = reference_pair.collection; + auto reference_field_name = reference_pair.field; + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(reference_collection_name); + if (collection == nullptr) { + return Option(400, "Referenced collection `" + reference_collection_name + + "` not found."); + } + + if (collection->get_schema().count(reference_field_name) == 0) { + return Option(400, "Referenced field `" + reference_field_name + + "` not found in the collection `" + reference_collection_name + "`."); + } + + if (!collection->get_schema().at(reference_field_name).index) { + return Option(400, "Referenced field `" + reference_field_name + + "` in the collection `" + reference_collection_name + "` must be indexed."); + } + + std::vector> documents; + auto value = document[field_name].get(); + collection->get_filter_ids(reference_field_name + ":=" + value, documents); + + if (documents[0].first != 1) { + delete [] documents[0].second; + auto match = " `" + reference_field_name + ": " + value + "` "; + return Option(400, documents[0].first < 1 ? + "Referenced document having" + match + "not found in the collection `" + + reference_collection_name + "`." : + "Multiple documents having" + match + "found in the collection `" + + reference_collection_name + "`."); + } + + document[field_name + REFERENCE_HELPER_FIELD_SUFFIX] = *(documents[0].second); + delete [] documents[0].second; + } + return Option(doc_seq_id_t{seq_id, true}); } else { if(!document["id"].is_string()) { @@ -2368,11 +2418,10 @@ Option Collection::get_reference_filter_ids(const std::string & filter_que std::shared_lock lock(mutex); std::string reference_field_name; - for (auto const& field: fields) { - if (!field.reference.empty() && - field.reference.find(collection_name) == 0 && - field.reference.find('.') == collection_name.size()) { - reference_field_name = field.name; + for (auto const& pair: reference_fields) { + auto reference_pair = pair.second; + if (reference_pair.collection == collection_name) { + reference_field_name = reference_pair.field; break; } } @@ -2390,7 +2439,8 @@ Option Collection::get_reference_filter_ids(const std::string & filter_que return filter_op; } - reference_field_name += "_sequence_id"; + // Reference helper field has the sequence id of other collection's documents. + reference_field_name += REFERENCE_HELPER_FIELD_SUFFIX; index->do_reference_filtering_with_lock(reference_index_ids, filter_tree_root, reference_field_name); delete filter_tree_root; @@ -3400,6 +3450,10 @@ SynonymIndex* Collection::get_synonym_index() { return synonym_index; } +spp::sparse_hash_map Collection::get_reference_fields() { + return reference_fields; +} + Option Collection::persist_collection_meta() { // first compact nested fields (to keep only parents of expanded children) field::compact_nested_fields(nested_fields); @@ -4174,6 +4228,14 @@ Index* Collection::init_index() { if(field.nested) { nested_fields.emplace(field.name, field); } + + if(!field.reference.empty()) { + auto dot_index = field.reference.find('.'); + auto collection_name = field.reference.substr(0, dot_index); + auto field_name = field.reference.substr(dot_index + 1); + + reference_fields.emplace(field.name, reference_pair(collection_name, field_name)); + } } field::compact_nested_fields(nested_fields); diff --git a/src/field.cpp b/src/field.cpp index e5aa527c..c941ae30 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -727,6 +727,15 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso auto vec_dist = magic_enum::enum_cast(field_json[fields::vec_dist].get()).value(); + if (!field_json[fields::reference].get().empty()) { + std::vector tokens; + StringUtils::split(field_json[fields::reference].get(), tokens, "."); + + if (tokens.size() < 2) { + return Option(400, "Invalid reference `" + field_json[fields::reference].get() + "`."); + } + } + the_fields.emplace_back( field(field_json[fields::name], field_json[fields::type], field_json[fields::facet], field_json[fields::optional], field_json[fields::index], field_json[fields::locale], @@ -737,8 +746,8 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso if (!field_json[fields::reference].get().empty()) { the_fields.emplace_back( - field(field_json[fields::name].get() + "_sequence_id", "int64", false, - field_json[fields::optional], true) + field(field_json[fields::name].get() + Collection::REFERENCE_HELPER_FIELD_SUFFIX, + "int64", false, field_json[fields::optional], true) ); } diff --git a/src/index.cpp b/src/index.cpp index 057c66b3..bf4987fd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -431,50 +431,6 @@ Option Index::validate_index_in_memory(nlohmann::json& document, uint3 continue; } - if (!a_field.reference.empty()) { - // Add foo_sequence_id field in the document. - - std::vector tokens; - StringUtils::split(a_field.reference, tokens, "."); - - if (tokens.size() < 2) { - return Option<>(400, "Invalid reference `" + a_field.reference + "`."); - } - - auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(tokens[0]); - if (collection == nullptr) { - return Option<>(400, "Referenced collection `" + tokens[0] + "` not found."); - } - - if (collection->get_schema().count(tokens[1]) == 0) { - return Option<>(400, "Referenced field `" + tokens[1] + "` not found in the collection `" - + tokens[0] + "`."); - } - - auto referenced_field_name = tokens[1]; - if (!collection->get_schema().at(referenced_field_name).index) { - return Option<>(400, "Referenced field `" + tokens[1] + "` in the collection `" - + tokens[0] + "` must be indexed."); - } - - std::vector> documents; - auto value = document[a_field.name].get(); - collection->get_filter_ids(referenced_field_name + ":=" + value, documents); - - if (documents[0].first != 1) { - delete [] documents[0].second; - auto match = " `" + referenced_field_name + "` = `" + value + "` "; - return Option<>(400, documents[0].first < 1 ? - "Referenced document having" + match + "not found in the collection `" + tokens[0] + "`." : - "Multiple documents having" + match + "found in the collection `" + tokens[0] + "`."); - } - - document[a_field.name + "_sequence_id"] = *(documents[0].second); - - delete [] documents[0].second; - } - if(document.count(field_name) == 0) { return Option<>(400, "Field `" + field_name + "` has been declared in the schema, " "but is not found in the document."); @@ -2168,7 +2124,7 @@ void Index::do_filtering_with_lock(uint32_t*& filter_ids, void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, - const std::string& reference_field_name) const { + const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); @@ -2178,8 +2134,8 @@ void Index::do_reference_filtering_with_lock(std::pair& ref for (uint32_t i = 0; i < reference_index_ids.first; i++) { auto filtered_doc_id = reference_index_ids.second[i]; - // Extract the sequence_id from the reference field. - vector.push_back(sort_index.at(reference_field_name)->at(filtered_doc_id)); + // Extract the sequence id. + vector.push_back(sort_index.at(reference_helper_field_name)->at(filtered_doc_id)); } std::sort(vector.begin(), vector.end()); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index e7e5636f..7d45523a 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -78,20 +78,39 @@ TEST_F(CollectionJoinTest, SchemaReferenceField) { R"({ "name": "Customers", "fields": [ - {"name": "product_id", "type": "string", "reference": "Products.product_id"}, + {"name": "product_id", "type": "string", "reference": "foo"}, {"name": "customer_name", "type": "string"}, {"name": "product_price", "type": "float"} ] })"_json; collection_create_op = collectionManager.create_collection(schema_json); - ASSERT_TRUE(collection_create_op.ok()); + ASSERT_FALSE(collection_create_op.ok()); + ASSERT_EQ("Invalid reference `foo`.", collection_create_op.error()); + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "product_id", "type": "string", "reference": "Products.product_id"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"} + ] + })"_json; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); auto collection = collection_create_op.get(); auto schema = collection->get_schema(); - ASSERT_EQ(schema.at("customer_name").reference, ""); - ASSERT_EQ(schema.at("product_id").reference, "Products.product_id"); + ASSERT_EQ(schema.count("customer_name"), 1); + ASSERT_TRUE(schema.at("customer_name").reference.empty()); + ASSERT_EQ(schema.count("product_id"), 1); + ASSERT_FALSE(schema.at("product_id").reference.empty()); + + auto reference_fields = collection->get_reference_fields(); + ASSERT_EQ(reference_fields.count("product_id"), 1); + ASSERT_EQ(reference_fields.at("product_id").collection, "Products"); + ASSERT_EQ(reference_fields.at("product_id").field, "product_id"); // Add a `foo_sequence_id` field in the schema for `foo` reference field. ASSERT_EQ(schema.count("product_id_sequence_id"), 1); @@ -108,11 +127,12 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { {"name": "customer_id", "type": "string"}, {"name": "customer_name", "type": "string"}, {"name": "product_price", "type": "float"}, - {"name": "product_id", "type": "string", "reference": "foo"} + {"name": "reference_id", "type": "string", "reference": "products.product_id"} ] })"_json; auto collection_create_op = collectionManager.create_collection(customers_schema_json); ASSERT_TRUE(collection_create_op.ok()); + auto customer_collection = collection_create_op.get(); nlohmann::json customer_json = R"({ "customer_id": "customer_a", @@ -120,27 +140,17 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { "product_price": 143, "product_id": "a" })"_json; - - auto customer_collection = collection_create_op.get(); auto add_doc_op = customer_collection->add(customer_json.dump()); + ASSERT_FALSE(add_doc_op.ok()); - ASSERT_EQ("Invalid reference `foo`.", add_doc_op.error()); - collectionManager.drop_collection("Customers"); + ASSERT_EQ("Missing the required reference field `reference_id` in the document.", add_doc_op.error()); - customers_schema_json = - R"({ - "name": "Customers", - "fields": [ - {"name": "customer_id", "type": "string"}, - {"name": "customer_name", "type": "string"}, - {"name": "product_price", "type": "float"}, - {"name": "product_id", "type": "string", "reference": "products.product_id"} - ] - })"_json; - collection_create_op = collectionManager.create_collection(customers_schema_json); - ASSERT_TRUE(collection_create_op.ok()); - - customer_collection = collection_create_op.get(); + customer_json = R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 143, + "reference_id": "a" + })"_json; add_doc_op = customer_collection->add(customer_json.dump()); ASSERT_FALSE(add_doc_op.ok()); ASSERT_EQ("Referenced collection `products` not found.", add_doc_op.error()); @@ -153,7 +163,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { {"name": "customer_id", "type": "string"}, {"name": "customer_name", "type": "string"}, {"name": "product_price", "type": "float"}, - {"name": "product_id", "type": "string", "reference": "Products.id"} + {"name": "reference_id", "type": "string", "reference": "Products.id"} ] })"_json; collection_create_op = collectionManager.create_collection(customers_schema_json); @@ -184,7 +194,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { {"name": "customer_id", "type": "string"}, {"name": "customer_name", "type": "string"}, {"name": "product_price", "type": "float"}, - {"name": "product_id", "type": "string", "reference": "Products.product_id"} + {"name": "reference_id", "type": "string", "reference": "Products.product_id"} ] })"_json; collection_create_op = collectionManager.create_collection(customers_schema_json); @@ -209,7 +219,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { ASSERT_TRUE(collection_create_op.ok()); add_doc_op = customer_collection->add(customer_json.dump()); - ASSERT_EQ("Referenced document having `product_id` = `a` not found in the collection `Products`.", add_doc_op.error()); + ASSERT_EQ("Referenced document having `product_id: a` not found in the collection `Products`.", add_doc_op.error()); std::vector products = { R"({ @@ -231,9 +241,9 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { ASSERT_TRUE(add_op.ok()); } - customer_json["product_id"] = "product_a"; + customer_json["reference_id"] = "product_a"; add_doc_op = customer_collection->add(customer_json.dump()); - ASSERT_EQ("Multiple documents having `product_id` = `product_a` found in the collection `Products`.", add_doc_op.error()); + ASSERT_EQ("Multiple documents having `product_id: product_a` found in the collection `Products`.", add_doc_op.error()); collectionManager.drop_collection("Products"); products[1]["product_id"] = "product_b"; @@ -255,6 +265,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { } ASSERT_TRUE(add_op.ok()); } + collectionManager.drop_collection("Customers"); customers_schema_json = R"({ @@ -263,7 +274,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { {"name": "customer_id", "type": "string"}, {"name": "customer_name", "type": "string"}, {"name": "product_price", "type": "float"}, - {"name": "product_id", "type": "string", "reference": "Products.product_id"} + {"name": "reference_id", "type": "string", "reference": "Products.product_id"} ] })"_json; collection_create_op = collectionManager.create_collection(customers_schema_json); @@ -272,12 +283,12 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { customer_collection = collection_create_op.get(); add_doc_op = customer_collection->add(customer_json.dump()); ASSERT_TRUE(add_doc_op.ok()); - ASSERT_EQ(customer_collection->get("0").get().count("product_id_sequence_id"), 1); + ASSERT_EQ(customer_collection->get("0").get().count("reference_id_sequence_id"), 1); nlohmann::json document; // Referenced document's sequence_id must be valid. auto get_op = collectionManager.get_collection("Products")->get_document_from_store( - customer_collection->get("0").get()["product_id_sequence_id"].get(), + customer_collection->get("0").get()["reference_id_sequence_id"].get(), document); ASSERT_TRUE(get_op.ok()); ASSERT_EQ(document.count("product_id"), 1); @@ -393,4 +404,4 @@ TEST_F(CollectionJoinTest, FilterByReferenceField) { ASSERT_EQ(1, result["found"].get()); ASSERT_EQ(1, result["hits"].size()); ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get()); -} \ No newline at end of file +} From 16d6a5cbf05e4c0eb6abc1e37add76ff0f4153eb Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 3 Feb 2023 14:30:17 +0530 Subject: [PATCH 19/27] Fix double locking of collection mutex. --- include/index.h | 24 ++++-- src/collection.cpp | 4 +- src/index.cpp | 61 +++++++------- test/collection_join_test.cpp | 147 +++++++++++++++++++++++++++++++++- 4 files changed, 198 insertions(+), 38 deletions(-) diff --git a/include/index.h b/include/index.h index e142e6ae..e0875935 100644 --- a/include/index.h +++ b/include/index.h @@ -484,21 +484,27 @@ private: uint32_t*& ids, size_t& ids_len) const; - void do_filtering(filter_node_t* const root) const; + void do_filtering(filter_node_t* const root, const std::string& collection_name) const; - void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const; + void rearranging_recursive_filter (uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const; void recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root, - const bool enable_short_circuit = false) const; + const std::string& collection_name) const; void adaptive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const filter_tree_root, - const bool enable_short_circuit = false) const; + const std::string& collection_name = "") const; - void get_filter_matches(filter_node_t* const root, std::vector>& vec) const; + void get_filter_matches(filter_node_t* const root, + std::vector>& vec, + const std::string& collection_name) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; @@ -656,7 +662,7 @@ public: // Public operations - void run_search(search_args* search_params); + void run_search(search_args* search_params, const std::string& collection_name); void search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, @@ -679,7 +685,8 @@ public: size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, enable_t split_join_tokens, - const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold) const; + const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold, + const std::string& collection_name) const; void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name); @@ -720,7 +727,8 @@ public: void do_filtering_with_lock( uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t* filter_tree_root) const; + filter_node_t* filter_tree_root, + const std::string& collection_name) const; void do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, diff --git a/src/collection.cpp b/src/collection.cpp index 2e2b8b36..cba0d545 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1449,7 +1449,7 @@ Option Collection::search(const std::string & raw_query, filter_curated_hits, split_join_tokens, vector_query, facet_sample_percent, facet_sample_threshold); - index->run_search(search_params); + index->run_search(search_params, name); // for grouping we have to re-aggregate @@ -2405,7 +2405,7 @@ Option Collection::get_filter_ids(const std::string & filter_query, uint32_t* filter_ids = nullptr; uint32_t filter_ids_len = 0; - index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root); + index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root, name); index_ids.emplace_back(filter_ids_len, filter_ids); delete filter_tree_root; diff --git a/src/index.cpp b/src/index.cpp index bf4987fd..a14b5274 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1617,7 +1617,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, ids = out; } -void Index::do_filtering(filter_node_t* const root) const { +void Index::do_filtering(filter_node_t* const root, const std::string& collection_name) const { // auto begin = std::chrono::high_resolution_clock::now(); const filter a_filter = root->filter_exp; @@ -1628,7 +1628,7 @@ void Index::do_filtering(filter_node_t* const root) const { auto collection = cm.get_collection(a_filter.referenced_collection_name); auto op = collection->get_reference_filter_ids(a_filter.field_name, - cm.get_collection_with_id(collection_id)->get_name(), + collection_name, root->match_index_ids); if (!op.ok()) { return; @@ -1957,26 +1957,29 @@ void Index::do_filtering(filter_node_t* const root) const { LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ } -void Index::get_filter_matches(filter_node_t* const root, std::vector>& vec) const { +void Index::get_filter_matches(filter_node_t* const root, + std::vector>& vec, + const std::string& collection_name) const { if (root == nullptr) { return; } if (root->isOperator) { if (root->filter_operator == AND) { - get_filter_matches(root->left, vec); - get_filter_matches(root->right, vec); + get_filter_matches(root->left, vec, collection_name); + get_filter_matches(root->right, vec, collection_name); } else { uint32_t *l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { - rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left, collection_name); } uint32_t *r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right); + rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right, collection_name); } root->match_index_ids.first = ArrayUtils::or_scalar( @@ -1992,7 +1995,7 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectormatch_index_ids.first, root); } @@ -2031,9 +2034,12 @@ void evaluate_rearranged_filter_tree(uint32_t*& filter_ids, } } -void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const { +void Index::rearranging_recursive_filter(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const { std::vector> vec; - get_filter_matches(root, vec); + get_filter_matches(root, vec, collection_name); std::sort(vec.begin(), vec.end(), [](const std::pair& lhs, const std::pair& rhs) { @@ -2050,7 +2056,7 @@ void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter void Index::recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root, - const bool enable_short_circuit) const { + const std::string& collection_name) const { if (root == nullptr) { return; } @@ -2059,15 +2065,13 @@ void Index::recursive_filter(uint32_t*& filter_ids, uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { - recursive_filter(l_filter_ids, l_filter_ids_length, root->left, - enable_short_circuit); + recursive_filter(l_filter_ids, l_filter_ids_length, root->left,collection_name); } uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - recursive_filter(r_filter_ids, r_filter_ids_length, root->right, - enable_short_circuit); + recursive_filter(r_filter_ids, r_filter_ids_length, root->right,collection_name); } uint32_t* filtered_results = nullptr; @@ -2088,7 +2092,7 @@ void Index::recursive_filter(uint32_t*& filter_ids, return; } - do_filtering(root); + do_filtering(root, collection_name); filter_ids_length = root->match_index_ids.first; filter_ids = root->match_index_ids.second; @@ -2099,7 +2103,7 @@ void Index::recursive_filter(uint32_t*& filter_ids, void Index::adaptive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const filter_tree_root, - const bool enable_short_circuit) const { + const std::string& collection_name) const { if (filter_tree_root == nullptr) { return; } @@ -2109,24 +2113,25 @@ void Index::adaptive_filter(uint32_t*& filter_ids, (*filter_tree_root->metrics).and_operator_count > 0 && // If there are more || in the filter tree than &&, we'll not gain much by rearranging the filter tree. ((float) (*filter_tree_root->metrics).or_operator_count / (float) (*filter_tree_root->metrics).and_operator_count < 0.5)) { - rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root); + rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); } else { - recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); + recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); } } void Index::do_filtering_with_lock(uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t* filter_tree_root) const { + filter_node_t* filter_tree_root, + const std::string& collection_name) const { std::shared_lock lock(mutex); - adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, false); + adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); } void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); - adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); + adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root); std::vector vector; vector.reserve(reference_index_ids.first); @@ -2142,7 +2147,7 @@ void Index::do_reference_filtering_with_lock(std::pair& ref std::copy(vector.begin(), vector.end(), reference_index_ids.second); } -void Index::run_search(search_args* search_params) { +void Index::run_search(search_args* search_params, const std::string& collection_name) { search(search_params->field_query_tokens, search_params->search_fields, search_params->match_type, @@ -2175,7 +2180,8 @@ void Index::run_search(search_args* search_params) { search_params->split_join_tokens, search_params->vector_query, search_params->facet_sample_percent, - search_params->facet_sample_threshold); + search_params->facet_sample_threshold, + collection_name); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -2625,7 +2631,8 @@ void Index::search(std::vector& field_query_tokens, const std::v const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, const enable_t split_join_tokens, const vector_query_t& vector_query, - size_t facet_sample_percent, size_t facet_sample_threshold) const { + size_t facet_sample_percent, size_t facet_sample_threshold, + const std::string& collection_name) const { // process the filters @@ -2634,7 +2641,7 @@ void Index::search(std::vector& field_query_tokens, const std::v std::shared_lock lock(mutex); - adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, true); + adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); if (filter_tree_root != nullptr && filter_ids_length == 0) { delete [] filter_ids; @@ -4730,7 +4737,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root, true); + adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root); } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 7d45523a..ab1936d2 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -299,7 +299,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { collectionManager.drop_collection("Products"); } -TEST_F(CollectionJoinTest, FilterByReferenceField) { +TEST_F(CollectionJoinTest, FilterByReferenceField_SingleMatch) { auto schema_json = R"({ "name": "Products", @@ -404,4 +404,149 @@ TEST_F(CollectionJoinTest, FilterByReferenceField) { ASSERT_EQ(1, result["found"].get()); ASSERT_EQ(1, result["hits"].size()); ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get()); + +// collectionManager.drop_collection("Customers"); +// collectionManager.drop_collection("Products"); } + +TEST_F(CollectionJoinTest, FilterByReferenceField_MultipleMatch) { + auto schema_json = + R"({ + "name": "Users", + "fields": [ + {"name": "user_id", "type": "string"}, + {"name": "user_name", "type": "string"} + ] + })"_json; + std::vector documents = { + R"({ + "user_id": "user_a", + "user_name": "Roshan" + })"_json, + R"({ + "user_id": "user_b", + "user_name": "Ruby" + })"_json, + R"({ + "user_id": "user_c", + "user_name": "Joe" + })"_json, + R"({ + "user_id": "user_d", + "user_name": "Aby" + })"_json + }; + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Repos", + "fields": [ + {"name": "repo_id", "type": "string"}, + {"name": "repo_content", "type": "string"} + ] + })"_json; + documents = { + R"({ + "repo_id": "repo_a", + "repo_content": "body1" + })"_json, + R"({ + "repo_id": "repo_b", + "repo_content": "body2" + })"_json, + R"({ + "repo_id": "repo_c", + "repo_content": "body3" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Links", + "fields": [ + {"name": "repo_id", "type": "string", "reference": "Repos.repo_id"}, + {"name": "user_id", "type": "string", "reference": "Users.user_id"} + ] + })"_json; + documents = { + R"({ + "repo_id": "repo_a", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_a", + "user_id": "user_c" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_a" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_d" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_a" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_c" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_d" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + auto coll = collectionManager.get_collection("Users"); + + // Search for users linked to repo_b + auto result = coll->search("R", {"user_name"}, "$Links(repo_id:=repo_b)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD).get(); + + ASSERT_EQ(2, result["found"].get()); + ASSERT_EQ(2, result["hits"].size()); + ASSERT_EQ("user_b", result["hits"][0]["document"]["user_id"].get()); + ASSERT_EQ("user_a", result["hits"][1]["document"]["user_id"].get()); + +// collectionManager.drop_collection("Users"); +// collectionManager.drop_collection("Repos"); +// collectionManager.drop_collection("Links"); +} \ No newline at end of file From da1b327749d4092b3351303a4c78567ccbc26bd5 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 3 Feb 2023 17:09:24 +0530 Subject: [PATCH 20/27] Refactor `rearranging_recursive_filter`. --- include/field.h | 2 - include/index.h | 11 ++-- src/index.cpp | 130 +++++++++++++++++------------------------------- 3 files changed, 53 insertions(+), 90 deletions(-) diff --git a/include/field.h b/include/field.h index e189648d..6f5fe485 100644 --- a/include/field.h +++ b/include/field.h @@ -542,7 +542,6 @@ struct filter_node_t { bool isOperator; filter_node_t* left; filter_node_t* right; - std::pair match_index_ids = {0, nullptr}; filter_tree_metrics* metrics = nullptr; filter_node_t(filter filter_exp) @@ -561,7 +560,6 @@ struct filter_node_t { ~filter_node_t() { delete metrics; - delete[] match_index_ids.second; delete left; delete right; } diff --git a/include/index.h b/include/index.h index e0875935..78288464 100644 --- a/include/index.h +++ b/include/index.h @@ -484,7 +484,10 @@ private: uint32_t*& ids, size_t& ids_len) const; - void do_filtering(filter_node_t* const root, const std::string& collection_name) const; + void do_filtering(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const; void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, @@ -501,9 +504,9 @@ private: filter_node_t* const filter_tree_root, const std::string& collection_name = "") const; - void get_filter_matches(filter_node_t* const root, - std::vector>& vec, + void rearrange_filter_tree(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, const std::string& collection_name) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, diff --git a/src/index.cpp b/src/index.cpp index a14b5274..cb8b5562 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1617,7 +1617,10 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, ids = out; } -void Index::do_filtering(filter_node_t* const root, const std::string& collection_name) const { +void Index::do_filtering(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const { // auto begin = std::chrono::high_resolution_clock::now(); const filter a_filter = root->filter_exp; @@ -1627,12 +1630,17 @@ void Index::do_filtering(filter_node_t* const root, const std::string& collectio auto& cm = CollectionManager::get_instance(); auto collection = cm.get_collection(a_filter.referenced_collection_name); + std::pair reference_index_ids; auto op = collection->get_reference_filter_ids(a_filter.field_name, collection_name, - root->match_index_ids); + reference_index_ids); if (!op.ok()) { return; } + + filter_ids_length = reference_index_ids.first; + filter_ids = reference_index_ids.second; + return; } @@ -1645,9 +1653,17 @@ void Index::do_filtering(filter_node_t* const root, const std::string& collectio std::sort(result_ids.begin(), result_ids.end()); - root->match_index_ids.second = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), root->match_index_ids.second); - root->match_index_ids.first = result_ids.size(); + if (filter_ids_length == 0) { + filter_ids = new uint32[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), filter_ids); + filter_ids_length = result_ids.size(); + } else { + uint32_t* filtered_results = nullptr; + filter_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, &result_ids[0], + result_ids.size(), &filtered_results); + delete[] filter_ids; + filter_ids = filtered_results; + } return; } @@ -1947,8 +1963,8 @@ void Index::do_filtering(filter_node_t* const root, const std::string& collectio result_ids_len = to_include_ids_len; } - root->match_index_ids.first = result_ids_len; - root->match_index_ids.second = result_ids; + filter_ids = result_ids; + filter_ids_length = result_ids_len; /*long long int timeMillis = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() @@ -1957,100 +1973,51 @@ void Index::do_filtering(filter_node_t* const root, const std::string& collectio LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ } -void Index::get_filter_matches(filter_node_t* const root, - std::vector>& vec, - const std::string& collection_name) const { +void Index::rearrange_filter_tree(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const { if (root == nullptr) { return; } if (root->isOperator) { - if (root->filter_operator == AND) { - get_filter_matches(root->left, vec, collection_name); - get_filter_matches(root->right, vec, collection_name); - } else { - uint32_t *l_filter_ids = nullptr; - uint32_t l_filter_ids_length = 0; - if (root->left != nullptr) { - rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left, collection_name); - } - - uint32_t *r_filter_ids = nullptr; - uint32_t r_filter_ids_length = 0; - if (root->right != nullptr) { - rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right, collection_name); - } - - root->match_index_ids.first = ArrayUtils::or_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &(root->match_index_ids.second)); - - delete[] l_filter_ids; - delete[] r_filter_ids; - - vec.emplace_back(root->match_index_ids.first, root); - } - - return; - } - - do_filtering(root, collection_name); - vec.emplace_back(root->match_index_ids.first, root); -} - -void evaluate_rearranged_filter_tree(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - std::vector>& vec, - size_t& index) { - if (root == nullptr) { - return; - } - - if (root->isOperator && root->filter_operator == AND) { uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { - evaluate_rearranged_filter_tree(l_filter_ids, l_filter_ids_length, root->left, vec, index); + rearrange_filter_tree(l_filter_ids, l_filter_ids_length,root->left, collection_name); } uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - evaluate_rearranged_filter_tree(r_filter_ids, r_filter_ids_length, root->right, vec, index); + rearrange_filter_tree(r_filter_ids, r_filter_ids_length, root->right, collection_name); } - root->match_index_ids.first = ArrayUtils::and_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &(root->match_index_ids.second)); + if (root->filter_operator == AND) { + filter_ids_length = std::min(l_filter_ids_length, r_filter_ids_length); + } else { + filter_ids_length = l_filter_ids_length + r_filter_ids_length; + } - filter_ids_length = root->match_index_ids.first; - filter_ids = root->match_index_ids.second; - } else { - filter_ids_length = vec[index].first; - filter_ids = vec[index].second->match_index_ids.second; - index++; + if (l_filter_ids_length > r_filter_ids_length) { + std::swap(root->left, root->right); + } + + delete[] l_filter_ids; + delete[] r_filter_ids; + return; } + + do_filtering(filter_ids, filter_ids_length, root, collection_name); } void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root, const std::string& collection_name) const { - std::vector> vec; - get_filter_matches(root, vec, collection_name); - - std::sort(vec.begin(), vec.end(), - [](const std::pair& lhs, const std::pair& rhs) { - return lhs.first < rhs.first; - }); - - size_t index = 0; - evaluate_rearranged_filter_tree(filter_ids, filter_ids_length, root, vec, index); - - // To disable deletion of filter_ids when filter tree is destructed. - root->match_index_ids.second = nullptr; + rearrange_filter_tree(filter_ids, filter_ids_length, root, collection_name); + recursive_filter(filter_ids, filter_ids_length, root, collection_name); } void Index::recursive_filter(uint32_t*& filter_ids, @@ -2092,12 +2059,7 @@ void Index::recursive_filter(uint32_t*& filter_ids, return; } - do_filtering(root, collection_name); - filter_ids_length = root->match_index_ids.first; - filter_ids = root->match_index_ids.second; - - // Prevents double deletion. We'll be deleting this array upstream and when the filter tree is destructed. - root->match_index_ids.second = nullptr; + do_filtering(filter_ids, filter_ids_length, root, collection_name); } void Index::adaptive_filter(uint32_t*& filter_ids, From 961e4330cf916c4ae776289bb9596582b0d7f53d Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 7 Feb 2023 10:53:18 +0530 Subject: [PATCH 21/27] Fix tests. --- test/collection_join_test.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index ab1936d2..98e7663f 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -299,7 +299,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { collectionManager.drop_collection("Products"); } -TEST_F(CollectionJoinTest, FilterByReferenceField_SingleMatch) { +TEST_F(CollectionJoinTest, FilterByReference_SingleMatch) { auto schema_json = R"({ "name": "Products", @@ -377,7 +377,7 @@ TEST_F(CollectionJoinTest, FilterByReferenceField_SingleMatch) { ASSERT_TRUE(add_op.ok()); } - auto coll = collectionManager.get_collection("Products"); + auto coll = collectionManager.get_collection_unsafe("Products"); auto search_op = coll->search("s", {"product_name"}, "$foo:=customer_a", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD); ASSERT_FALSE(search_op.ok()); @@ -405,11 +405,11 @@ TEST_F(CollectionJoinTest, FilterByReferenceField_SingleMatch) { ASSERT_EQ(1, result["hits"].size()); ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get()); -// collectionManager.drop_collection("Customers"); -// collectionManager.drop_collection("Products"); + collectionManager.drop_collection("Customers"); + collectionManager.drop_collection("Products"); } -TEST_F(CollectionJoinTest, FilterByReferenceField_MultipleMatch) { +TEST_F(CollectionJoinTest, FilterByReference_MultipleMatch) { auto schema_json = R"({ "name": "Users", @@ -535,7 +535,7 @@ TEST_F(CollectionJoinTest, FilterByReferenceField_MultipleMatch) { ASSERT_TRUE(add_op.ok()); } - auto coll = collectionManager.get_collection("Users"); + auto coll = collectionManager.get_collection_unsafe("Users"); // Search for users linked to repo_b auto result = coll->search("R", {"user_name"}, "$Links(repo_id:=repo_b)", {}, {}, {0}, @@ -546,7 +546,7 @@ TEST_F(CollectionJoinTest, FilterByReferenceField_MultipleMatch) { ASSERT_EQ("user_b", result["hits"][0]["document"]["user_id"].get()); ASSERT_EQ("user_a", result["hits"][1]["document"]["user_id"].get()); -// collectionManager.drop_collection("Users"); -// collectionManager.drop_collection("Repos"); -// collectionManager.drop_collection("Links"); + collectionManager.drop_collection("Users"); + collectionManager.drop_collection("Repos"); + collectionManager.drop_collection("Links"); } \ No newline at end of file From b1ef695461536e2ddf3d87c89c1244f9d4f17c69 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 9 Feb 2023 11:50:58 +0530 Subject: [PATCH 22/27] Reference `include_fields`. --- include/collection.h | 10 ++- include/string_utils.h | 2 + src/collection.cpp | 142 +++++++++++++++++++++++++++--- src/collection_manager.cpp | 6 ++ src/string_utils.cpp | 36 ++++++++ test/collection_join_test.cpp | 160 ++++++++++++++++++++++++++++++++++ 6 files changed, 343 insertions(+), 13 deletions(-) diff --git a/include/collection.h b/include/collection.h index 72ef1280..e6f54e3e 100644 --- a/include/collection.h +++ b/include/collection.h @@ -356,8 +356,10 @@ public: static void remove_flat_fields(nlohmann::json& document); - static void prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, - const tsl::htrie_set& exclude_names, const std::string& parent_name = "", size_t depth = 0); + static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, + const tsl::htrie_set& exclude_names, const std::string& parent_name = "", size_t depth = 0, + const uint32_t doc_sequence_id = 0, const std::string& collection_name = "", + const std::map& reference_filter_map = {}); const Index* _get_index() const; @@ -449,6 +451,8 @@ public: Option get_filter_ids(const std::string & filter_query, std::vector>& index_ids) const; + Option get_reference_field(const std::string & collection_name) const; + Option get_reference_filter_ids(const std::string & filter_query, const std::string & collection_name, std::pair& reference_index_ids) const; @@ -515,8 +519,8 @@ public: void process_highlight_fields(const std::vector& search_fields, const std::vector& raw_search_fields, - const tsl::htrie_set& exclude_fields, const tsl::htrie_set& include_fields, + const tsl::htrie_set& exclude_fields, const std::vector& highlight_field_names, const std::vector& highlight_full_field_names, const std::vector& infixes, diff --git a/include/string_utils.h b/include/string_utils.h index 3ba28925..5ad2c9a6 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -379,4 +379,6 @@ struct StringUtils { static size_t get_num_chars(const std::string& text); static Option tokenize_filter_query(const std::string& filter_query, std::queue& tokens); + + static Option split_include_fields(const std::string& include_fields, std::vector& tokens); }; diff --git a/src/collection.cpp b/src/collection.cpp index cba0d545..331669f4 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -946,6 +946,12 @@ Option Collection::extract_field_name(const std::string& field_name, const bool extract_only_string_fields, const bool enable_nested_fields, const bool handle_wildcard) { + // Reference to other collection + if (field_name[0] == '$') { + processed_search_fields.push_back(field_name); + return Option(true); + } + if(field_name == "id") { processed_search_fields.push_back(field_name); return Option(true); @@ -991,6 +997,19 @@ Option Collection::extract_field_name(const std::string& field_name, return Option(true); } +void get_reference_filters(filter_node_t const* const root, std::map& reference_filter_map) { + if (root == nullptr) { + return; + } + + if (!root->isOperator && !root->filter_exp.referenced_collection_name.empty()) { + reference_filter_map[root->filter_exp.referenced_collection_name] = root->filter_exp.field_name; + } + + get_reference_filters(root->left, reference_filter_map); + get_reference_filters(root->right, reference_filter_map); +} + Option Collection::search(const std::string & raw_query, const std::vector& raw_search_fields, const std::string & filter_query, const std::vector& facet_fields, @@ -1747,7 +1766,24 @@ Option Collection::search(const std::string & raw_query, } remove_flat_fields(document); - prune_doc(document, include_fields_full, exclude_fields_full); + auto doc_id_op = doc_id_to_seq_id(document["id"].get()); + if (!doc_id_op.ok()) { + return Option(doc_id_op.code(), doc_id_op.error()); + } + + std::map reference_filter_map; + get_reference_filters(filter_tree_root, reference_filter_map); + auto prune_op = prune_doc(document, + include_fields_full, + exclude_fields_full, + "", + 0, + doc_id_op.get(), + name, + reference_filter_map); + if (!prune_op.ok()) { + return Option(prune_op.code(), prune_op.error()); + } wrapper_doc["document"] = document; wrapper_doc["highlight"] = highlight_res; @@ -2412,9 +2448,7 @@ Option Collection::get_filter_ids(const std::string & filter_query, return Option(true); } -Option Collection::get_reference_filter_ids(const std::string & filter_query, - const std::string & collection_name, - std::pair& reference_index_ids) const { +Option Collection::get_reference_field(const std::string & collection_name) const { std::shared_lock lock(mutex); std::string reference_field_name; @@ -2427,10 +2461,23 @@ Option Collection::get_reference_filter_ids(const std::string & filter_que } if (reference_field_name.empty()) { - return Option(400, "Could not find any field in `" + name + "` referencing the collection `" - + collection_name + "`."); + return Option(400, "Could not find any field in `" + name + "` referencing the collection `" + + collection_name + "`."); } + return Option(reference_field_name); +} + +Option Collection::get_reference_filter_ids(const std::string & filter_query, + const std::string & collection_name, + std::pair& reference_index_ids) const { + auto reference_field_op = get_reference_field(collection_name); + if (!reference_field_op.ok()) { + return Option(reference_field_op.code(), reference_field_op.error()); + } + + std::shared_lock lock(mutex); + const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_"; filter_node_t* filter_tree_root = nullptr; Option filter_op = filter::parse_filter_query(filter_query, search_schema, @@ -2440,8 +2487,8 @@ Option Collection::get_reference_filter_ids(const std::string & filter_que } // Reference helper field has the sequence id of other collection's documents. - reference_field_name += REFERENCE_HELPER_FIELD_SUFFIX; - index->do_reference_filtering_with_lock(reference_index_ids, filter_tree_root, reference_field_name); + auto field_name = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX; + index->do_reference_filtering_with_lock(reference_index_ids, filter_tree_root, field_name); delete filter_tree_root; return Option(true); @@ -3684,10 +3731,12 @@ void Collection::remove_flat_fields(nlohmann::json& document) { } } -void Collection::prune_doc(nlohmann::json& doc, +Option Collection::prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, - const std::string& parent_name, size_t depth) { + const std::string& parent_name, size_t depth, + const uint32_t doc_sequence_id, const std::string& collection_name, + const std::map& reference_filter_map) { // doc can only be an object auto it = doc.begin(); while(it != doc.end()) { @@ -3763,6 +3812,79 @@ void Collection::prune_doc(nlohmann::json& doc, it++; } + + auto reference_it = include_names.equal_prefix_range("$"); + for (auto reference = reference_it.first; reference != reference_it.second; reference++) { + auto ref = reference.key(); + size_t parenthesis_index = ref.find('('); + + auto ref_collection_name = ref.substr(1, parenthesis_index - 1); + auto reference_fields = ref.substr(parenthesis_index + 1, ref.size() - parenthesis_index - 2); + + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(ref_collection_name); + if (collection == nullptr) { + return Option(400, "Referenced collection `" + ref_collection_name + "` not found."); + } + + std::vector include_fields_vec; + StringUtils::split(reference_fields, include_fields_vec, ","); + + spp::sparse_hash_set include_fields, exclude_fields; + include_fields.insert(include_fields_vec.begin(), include_fields_vec.end()); + + tsl::htrie_set include_fields_full, exclude_fields_full; + auto include_exclude_op = collection->populate_include_exclude_fields(include_fields, exclude_fields, + include_fields_full, exclude_fields_full); + if (!include_exclude_op.ok()) { + return include_exclude_op; + } + + auto reference_field_op = collection->get_reference_field(collection_name); + if (!reference_field_op.ok()) { + return Option(reference_field_op.code(), reference_field_op.error()); + } + + std::vector> documents; + auto filter = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX + ":=" + std::to_string(doc_sequence_id); + if (reference_filter_map.count(ref_collection_name) > 0) { + filter += "&&"; + filter += reference_filter_map.at(ref_collection_name); + } + auto filter_op = collection->get_filter_ids(filter, documents); + if (!filter_op.ok()) { + return filter_op; + } + + if (documents[0].first == 0) { + continue; + } + + std::vector reference_docs; + reference_docs.reserve(documents[0].first); + for (size_t i = 0; i < documents[0].first; i++) { + auto doc_seq_id = documents[0].second[i]; + + nlohmann::json ref_doc; + auto get_doc_op = collection->get_document_from_store(doc_seq_id, ref_doc); + if (!get_doc_op.ok()) { + return get_doc_op; + } + + auto prune_op = prune_doc(ref_doc, include_fields_full, exclude_fields_full); + if (!prune_op.ok()) { + return prune_op; + } + + reference_docs.push_back(ref_doc); + } + + for (const auto &ref_doc: reference_docs) { + doc.update(ref_doc); + } + } + + return Option(true); } Option Collection::validate_alter_payload(nlohmann::json& schema_changes, diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 0fb4d5be..f10aacf2 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -907,6 +907,12 @@ Option CollectionManager::do_search(std::map& re if(key == FACET_BY){ StringUtils::split_facet(val, *find_str_list_it->second); } + else if(key == INCLUDE_FIELDS){ + auto op = StringUtils::split_include_fields(val, *find_str_list_it->second); + if (!op.ok()) { + return op; + } + } else{ StringUtils::split(val, *find_str_list_it->second, ","); } diff --git a/src/string_utils.cpp b/src/string_utils.cpp index a9409400..9fa65b4c 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -453,6 +453,42 @@ Option StringUtils::tokenize_filter_query(const std::string& filter_query, return Option(true); } +Option StringUtils::split_include_fields(const std::string& include_fields, std::vector& tokens) { + size_t start = 0, end = 0, size = include_fields.size(); + std::string include_field; + + while (true) { + auto range_pos = include_fields.find('$', start); + auto comma_pos = include_fields.find(',', start); + + if (range_pos == std::string::npos && comma_pos == std::string::npos) { + if (start < size - 1) { + tokens.push_back(include_fields.substr(start, size - start)); + } + break; + } else if (range_pos < comma_pos) { + end = include_fields.find(')', range_pos); + if (end == std::string::npos || end < include_fields.find('(', range_pos)) { + return Option(400, "Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`."); + } + + include_field = include_fields.substr(range_pos, (end - range_pos) + 1); + } else { + end = comma_pos; + include_field = include_fields.substr(start, end - start); + } + + include_field = trim(include_field); + if (!include_field.empty()) { + tokens.push_back(include_field); + } + + start = end + 1; + } + + return Option(true); +} + /*size_t StringUtils::unicode_length(const std::string& bytes) { std::wstring_convert, char32_t> utf8conv; return utf8conv.from_bytes(bytes).size(); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 98e7663f..e96f7f57 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -549,4 +549,164 @@ TEST_F(CollectionJoinTest, FilterByReference_MultipleMatch) { collectionManager.drop_collection("Users"); collectionManager.drop_collection("Repos"); collectionManager.drop_collection("Links"); +} + +TEST_F(CollectionJoinTest, IncludeFieldsByReference_SingleMatch) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string"}, + {"name": "product_name", "type": "string"}, + {"name": "product_description", "type": "string"} + ] + })"_json; + std::vector documents = { + R"({ + "product_id": "product_a", + "product_name": "shampoo", + "product_description": "Our new moisturizing shampoo is perfect for those with dry or damaged hair." + })"_json, + R"({ + "product_id": "product_b", + "product_name": "soap", + "product_description": "Introducing our all-natural, organic soap bar made with essential oils and botanical ingredients." + })"_json + }; + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Customers", + "fields": [ + {"name": "customer_id", "type": "string"}, + {"name": "customer_name", "type": "string"}, + {"name": "product_price", "type": "float"}, + {"name": "product_id", "type": "string", "reference": "Products.product_id"} + ] + })"_json; + documents = { + R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 143, + "product_id": "product_a" + })"_json, + R"({ + "customer_id": "customer_a", + "customer_name": "Joe", + "product_price": 73.5, + "product_id": "product_b" + })"_json, + R"({ + "customer_id": "customer_b", + "customer_name": "Dan", + "product_price": 75, + "product_id": "product_a" + })"_json, + R"({ + "customer_id": "customer_b", + "customer_name": "Dan", + "product_price": 140, + "product_id": "product_b" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + std::map req_params = { + {"collection", "Products"}, + {"q", "s"}, + {"query_by", "product_name"}, + {"filter_by", "$Customers(customer_id:=customer_a && product_price:<100)"}, + }; + + nlohmann::json embedded_params; + std::string json_res; + auto now_ts = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count(); + + req_params["include_fields"] = "$foo.bar"; + auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ("Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`.", search_op.error()); + + req_params["include_fields"] = "$foo(bar"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ("Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`.", search_op.error()); + + req_params["include_fields"] = "$foo(bar)"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ("Referenced collection `foo` not found.", search_op.error()); + + req_params["include_fields"] = "$Customers(bar)"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + nlohmann::json res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + ASSERT_EQ(0, res_obj["hits"][0]["document"].size()); + + req_params["include_fields"] = "$Customers(product_price)"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); + ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); + + req_params["include_fields"] = "$Customers(product_price, customer_id)"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); + ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("customer_id")); + ASSERT_EQ("customer_a", res_obj["hits"][0]["document"].at("customer_id")); + + req_params["include_fields"] = "*, $Customers(product_price, customer_id)"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + // 3 fields in Products document and 2 fields from Customers document + ASSERT_EQ(5, res_obj["hits"][0]["document"].size()); + + req_params["include_fields"] = "*, $Customers(product*)"; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + // 3 fields in Products document and 2 fields from Customers document + ASSERT_EQ(5, res_obj["hits"][0]["document"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id_sequence_id")); } \ No newline at end of file From 57908965ae2f488312a51ae99102e1711ebda019 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 9 Feb 2023 12:25:16 +0530 Subject: [PATCH 23/27] fix memory leak. --- src/collection.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/collection.cpp b/src/collection.cpp index 331669f4..1f27616d 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -3857,6 +3857,7 @@ Option Collection::prune_doc(nlohmann::json& doc, } if (documents[0].first == 0) { + delete[] documents[0].second; continue; } @@ -3879,6 +3880,8 @@ Option Collection::prune_doc(nlohmann::json& doc, reference_docs.push_back(ref_doc); } + delete[] documents[0].second; + for (const auto &ref_doc: reference_docs) { doc.update(ref_doc); } From 5eda7668b966de627cce6e0cd9ced834291259b9 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 14 Feb 2023 14:28:39 +0530 Subject: [PATCH 24/27] Refactor fuzzy search restrictions. --- include/art.h | 7 +- src/art.cpp | 113 ++++++++++++++++++++++------- src/index.cpp | 85 +++++----------------- test/art_test.cpp | 116 +++++++++++++++++++----------- test/collection_specific_test.cpp | 2 + test/collection_test.cpp | 2 +- 6 files changed, 185 insertions(+), 140 deletions(-) diff --git a/include/art.h b/include/art.h index 0502641c..a9715fac 100644 --- a/include/art.h +++ b/include/art.h @@ -276,9 +276,10 @@ int art_iter_prefix(art_tree *t, const unsigned char *prefix, int prefix_len, ar * Returns leaves that match a given string within a fuzzy distance of max_cost. */ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const int max_words, const token_ordering token_order, const bool prefix, - const uint32_t *filter_ids, size_t filter_ids_length, - std::vector &results, const std::set& exclude_leaves = {}); + const size_t max_words, const token_ordering token_order, + const bool prefix, bool last_token, const std::string& prev_token, + const uint32_t *filter_ids, const size_t filter_ids_length, + std::vector &results, std::set& exclude_leaves); void encode_int32(int32_t n, unsigned char *chars); diff --git a/src/art.cpp b/src/art.cpp index f778eace..835f77b9 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -21,6 +21,7 @@ #include #include "art.h" #include "logger.h" +#include "array_utils.h" /** * Macros to manipulate pointer tags @@ -940,10 +941,69 @@ void* art_delete(art_tree *t, const unsigned char *key, int key_len) { return child->max_token_count; }*/ +const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, + const uint32_t* filter_ids, const size_t filter_ids_length, + size_t& prev_token_doc_ids_len) { + + art_leaf* prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + if(prev_token.empty() || !prev_leaf) { + prev_token_doc_ids_len = filter_ids_length; + return filter_ids; + } + + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + uint32_t* prev_token_doc_ids = nullptr; + + if(filter_ids_length != 0) { + prev_token_doc_ids_len = ArrayUtils::and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), + filter_ids, filter_ids_length, + &prev_token_doc_ids); + } else { + prev_token_doc_ids_len = prev_leaf_ids.size(); + prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; + std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); + } + + return prev_token_doc_ids; +} + +bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, + const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, + std::set& exclude_leaves, const art_leaf* exact_leaf, + std::vector& results) { + + if(leaf == exact_leaf) { + return false; + } + + std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); + if(exclude_leaves.count(tok) != 0) { + return false; + } + + if(allowed_doc_ids_len != 0) { + if(!posting_t::contains_atleast_one(leaf->values, allowed_doc_ids, + allowed_doc_ids_len)) { + return false; + } + } + + exclude_leaves.emplace(tok); + results.push_back(leaf); + + return true; +} + int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, - const uint32_t* filter_ids, size_t filter_ids_length, - const std::set& exclude_leaves, const art_leaf* exact_leaf, - std::vector& results) { + const art_leaf* exact_leaf, + const bool last_token, const std::string& prev_token, + const uint32_t* allowed_doc_ids, size_t allowed_doc_ids_len, + const art_tree* t, std::set& exclude_leaves, std::vector& results) { printf("INSIDE art_topk_iter: root->type: %d\n", root->type); @@ -972,25 +1032,8 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r if (IS_LEAF(n)) { art_leaf *l = (art_leaf *) LEAF_RAW(n); //LOG(INFO) << "END LEAF SCORE: " << l->max_score; - - if(filter_ids_length == 0) { - std::string tok(reinterpret_cast(l->key), l->key_len - 1); - if(exclude_leaves.count(tok) != 0 || l == exact_leaf) { - continue; - } - results.push_back(l); - } else { - // we will push leaf only if filter matches with leaf IDs - bool found_atleast_one = posting_t::contains_atleast_one(l->values, filter_ids, filter_ids_length); - if(found_atleast_one) { - std::string tok(reinterpret_cast(l->key), l->key_len - 1); - if(exclude_leaves.count(tok) != 0 || l == exact_leaf) { - continue; - } - results.push_back(l); - } - } - + validate_and_add_leaf(l, last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + exclude_leaves, exact_leaf, results); continue; } @@ -1491,9 +1534,10 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * * Returns leaves that match a given string within a fuzzy distance of max_cost. */ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const int max_words, const token_ordering token_order, const bool prefix, - const uint32_t *filter_ids, size_t filter_ids_length, - std::vector &results, const std::set& exclude_leaves) { + const size_t max_words, const token_ordering token_order, const bool prefix, + bool last_token, const std::string& prev_token, + const uint32_t *filter_ids, const size_t filter_ids_length, + std::vector &results, std::set& exclude_leaves) { std::vector nodes; int irow[term_len + 1]; @@ -1525,8 +1569,15 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; + // documents that contain the previous token and/or filter ids + size_t allowed_doc_ids_len = 0; + const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_ids, filter_ids_length, + allowed_doc_ids_len); + for(auto node: nodes) { - art_topk_iter(node, token_order, max_words, filter_ids, filter_ids_length, exclude_leaves, exact_leaf, results); + art_topk_iter(node, token_order, max_words, exact_leaf, + last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + t, exclude_leaves, results); } if(token_order == FREQUENCY) { @@ -1536,7 +1587,11 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, } if(exact_leaf && min_cost == 0) { - results.insert(results.begin(), exact_leaf); + std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); + if(exclude_leaves.count(tok) == 0) { + results.insert(results.begin(), exact_leaf); + exclude_leaves.emplace(tok); + } } if(results.size() > max_words) { @@ -1551,6 +1606,10 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, << ", filter_ids_length: " << filter_ids_length; }*/ + if(allowed_doc_ids != filter_ids) { + delete [] allowed_doc_ids; + } + return 0; } diff --git a/src/index.cpp b/src/index.cpp index 6f9e8d7b..60a56a6d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3232,12 +3232,12 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, } //LOG(INFO) << "Searching for field: " << the_field.name << ", found token:" << token; + const auto& prev_token = last_token ? token_candidates_vec.back().candidates[0] : ""; std::vector field_leaves; - int max_words = 100000; art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, - costs[token_index], costs[token_index], max_words, token_order, prefix_search, - filter_ids, filter_ids_length, field_leaves, unique_tokens); + costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, + last_token, prev_token, filter_ids, filter_ids_length, field_leaves, unique_tokens); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3248,60 +3248,17 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, continue; } - uint32_t* prev_token_doc_ids = nullptr; // documents that contain the previous token - size_t prev_token_doc_ids_len = 0; - - if(last_token) { - auto& prev_token = token_candidates_vec.back().candidates[0]; - art_leaf* prev_leaf = static_cast( - art_search(search_index.at(the_field.name), - reinterpret_cast(prev_token.c_str()), - prev_token.size() + 1)); - - if(!prev_leaf) { - continue; - } - - std::vector prev_leaf_ids; - posting_t::merge({prev_leaf->values}, prev_leaf_ids); - - if(filter_ids_length != 0) { - prev_token_doc_ids_len = ArrayUtils::and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), - filter_ids, filter_ids_length, - &prev_token_doc_ids); - } else { - prev_token_doc_ids_len = prev_leaf_ids.size(); - prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; - std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); - } - } - for(size_t i = 0; i < field_leaves.size(); i++) { auto leaf = field_leaves[i]; std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); - if(unique_tokens.count(tok) == 0) { - if(last_token) { - if(!posting_t::contains_atleast_one(leaf->values, prev_token_doc_ids, - prev_token_doc_ids_len)) { - continue; - } - } - - unique_tokens.emplace(tok); - leaf_tokens.push_back(tok); - } - - if(leaf_tokens.size() >= max_candidates) { - token_cost_cache.emplace(token_cost_hash, leaf_tokens); - delete [] prev_token_doc_ids; - prev_token_doc_ids = nullptr; - goto token_done; - } + leaf_tokens.push_back(tok); } token_cost_cache.emplace(token_cost_hash, leaf_tokens); - delete [] prev_token_doc_ids; - prev_token_doc_ids = nullptr; + + if(leaf_tokens.size() >= max_candidates) { + goto token_done; + } } if(last_token && leaf_tokens.size() < max_candidates) { @@ -3330,10 +3287,9 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, } std::vector field_leaves; - int max_words = 100000; art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, - costs[token_index], costs[token_index], max_words, token_order, prefix_search, - filter_ids, filter_ids_length, field_leaves, unique_tokens); + costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, + false, "", filter_ids, filter_ids_length, field_leaves, unique_tokens); if(field_leaves.empty()) { // look at the next field @@ -3343,23 +3299,14 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, for(size_t i = 0; i < field_leaves.size(); i++) { auto leaf = field_leaves[i]; std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); - if(unique_tokens.count(tok) == 0) { - if(!posting_t::contains_atleast_one(leaf->values, &prev_token_doc_ids[0], - prev_token_doc_ids.size())) { - continue; - } - - unique_tokens.emplace(tok); - leaf_tokens.push_back(tok); - } - - if(leaf_tokens.size() >= max_candidates) { - token_cost_cache.emplace(token_cost_hash, leaf_tokens); - goto token_done; - } + leaf_tokens.push_back(tok); } token_cost_cache.emplace(token_cost_hash, leaf_tokens); + + if(leaf_tokens.size() >= max_candidates) { + goto token_done; + } } } } @@ -4741,7 +4688,7 @@ void Index::search_field(const uint8_t & field_id, // need less candidates for filtered searches since we already only pick tokens with results art_fuzzy_search(search_index.at(field_name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, - filter_ids, filter_ids_length, leaves, unique_tokens); + false, "", filter_ids, filter_ids_length, leaves, unique_tokens); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); diff --git a/test/art_test.cpp b/test/art_test.cpp index 0236b5e1..f8414653 100644 --- a/test/art_test.cpp +++ b/test/art_test.cpp @@ -18,6 +18,8 @@ art_document get_document(uint32_t id) { return document; } +std::set exclude_leaves; + TEST(ArtTest, test_art_init_and_destroy) { art_tree t; int res = art_tree_init(&t); @@ -587,22 +589,25 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf) { EXPECT_EQ(1, posting_t::first_id(l->values)); std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *) implement_key, strlen(implement_key) + 1, 0, 0, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *) implement_key, strlen(implement_key) + 1, 0, 0, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); const char* implement_key_typo1 = "implment"; const char* implement_key_typo2 = "implwnent"; leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 0, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 0, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) implement_key_typo2, strlen(implement_key_typo2) + 1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) implement_key_typo2, strlen(implement_key_typo2) + 1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -623,11 +628,12 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf_prefix) { std::vector leaves; std::string term = "aplication"; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -645,7 +651,7 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf_qlen_greater_than_key) { std::string term = "starkbin"; std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); } @@ -660,11 +666,12 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf_non_prefix) { std::string term = "spz"; std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -682,7 +689,7 @@ TEST(ArtTest, test_art_prefix_larger_than_key) { std::string term = "earrings"; std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); res = art_tree_destroy(&t); @@ -706,7 +713,7 @@ TEST(ArtTest, test_art_fuzzy_search_prefix_token_ordering) { } std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *) "e", 1, 0, 0, 3, MAX_SCORE, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *) "e", 1, 0, 0, 3, MAX_SCORE, true, false, "", nullptr, 0, leaves, exclude_leaves); std::string first_key(reinterpret_cast(leaves[0]->key), leaves[0]->key_len - 1); ASSERT_EQ("e", first_key); @@ -718,7 +725,8 @@ TEST(ArtTest, test_art_fuzzy_search_prefix_token_ordering) { ASSERT_EQ("elephant", third_key); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "enter", 5, 1, 1, 3, MAX_SCORE, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "enter", 5, 1, 1, 3, MAX_SCORE, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_TRUE(leaves.empty()); res = art_tree_destroy(&t); @@ -747,56 +755,65 @@ TEST(ArtTest, test_art_fuzzy_search) { auto begin = std::chrono::high_resolution_clock::now(); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "pltinum", strlen("pltinum"), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "pltinum", strlen("pltinum"), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); ASSERT_STREQ("platinumsmith", (const char *)leaves.at(0)->key); ASSERT_STREQ("platinum", (const char *)leaves.at(1)->key); leaves.clear(); + exclude_leaves.clear(); // extra char - art_fuzzy_search(&t, (const unsigned char *) "higghliving", strlen("higghliving") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *) "higghliving", strlen("higghliving") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("highliving", (const char *)leaves.at(0)->key); // transpose leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "zymosthneic", strlen("zymosthneic") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "zymosthneic", strlen("zymosthneic") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("zymosthenic", (const char *)leaves.at(0)->key); // transpose + missing leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("dacrycystalgia", (const char *)leaves.at(0)->key); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 1, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 1, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("dacrycystalgia", (const char *)leaves.at(0)->key); // missing char leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "gaberlunze", strlen("gaberlunze") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "gaberlunze", strlen("gaberlunze") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("gaberlunzie", (const char *)leaves.at(0)->key); // substituted char leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "eacemiferous", strlen("eacemiferous") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "eacemiferous", strlen("eacemiferous") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("racemiferous", (const char *)leaves.at(0)->key); // missing char + extra char leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "Sarbruckken", strlen("Sarbruckken") + 1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "Sarbruckken", strlen("Sarbruckken") + 1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("Saarbrucken", (const char *)leaves.at(0)->key); // multiple matching results leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "hown", strlen("hown") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "hown", strlen("hown") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(10, leaves.size()); std::set expected_words = {"town", "sown", "mown", "lown", "howl", "howk", "howe", "how", "horn", "hoon"}; @@ -809,23 +826,28 @@ TEST(ArtTest, test_art_fuzzy_search) { // fuzzy prefix search leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "lionhear", strlen("lionhear"), 0, 0, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "lionhear", strlen("lionhear"), 0, 0, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(3, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "lineage", strlen("lineage"), 0, 0, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "lineage", strlen("lineage"), 0, 0, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "liq", strlen("liq"), 0, 0, 50, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "liq", strlen("liq"), 0, 0, 50, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(39, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "antitraditiana", strlen("antitraditiana"), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "antitraditiana", strlen("antitraditiana"), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "antisocao", strlen("antisocao"), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "antisocao", strlen("antisocao"), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(6, leaves.size()); long long int timeMillis = std::chrono::duration_cast( @@ -855,7 +877,7 @@ TEST(ArtTest, test_art_fuzzy_search_unicode_chars) { EXPECT_EQ(1, posting_t::first_id(l->values)); std::vector leaves; - art_fuzzy_search(&t, (unsigned char *)key, strlen(key), 0, 0, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (unsigned char *)key, strlen(key), 0, 0, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); } @@ -879,7 +901,7 @@ TEST(ArtTest, test_art_fuzzy_search_extra_chars) { const char* query = "abbreviation"; std::vector leaves; - art_fuzzy_search(&t, (unsigned char *)query, strlen(query), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (unsigned char *)query, strlen(query), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -918,15 +940,16 @@ TEST(ArtTest, test_art_search_sku_like_tokens) { for (const auto &key : keys) { std::vector leaves; art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size(), 0, 0, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); leaves.clear(); + exclude_leaves.clear(); // non prefix art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size()+1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); } @@ -970,14 +993,17 @@ TEST(ArtTest, test_art_search_ill_like_tokens) { std::make_pair("ice", 2), }; + std::string key = "input"; + for (const auto &key : keys) { art_leaf* l = (art_leaf *) art_search(&t, (const unsigned char *)key.c_str(), key.size()+1); ASSERT_FALSE(l == nullptr); EXPECT_EQ(1, posting_t::num_ids(l->values)); std::vector leaves; + exclude_leaves.clear(); art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size(), 0, 0, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); if(key_to_count.count(key) != 0) { ASSERT_EQ(key_to_count[key], leaves.size()); @@ -987,10 +1013,14 @@ TEST(ArtTest, test_art_search_ill_like_tokens) { } leaves.clear(); + exclude_leaves.clear(); // non prefix art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size()+1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); + if(leaves.size() != 1) { + LOG(INFO) << key; + } ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); } @@ -1022,8 +1052,9 @@ TEST(ArtTest, test_art_search_ill_like_tokens2) { EXPECT_EQ(1, posting_t::num_ids(l->values)); std::vector leaves; + exclude_leaves.clear(); art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size(), 0, 0, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); if(key == "illustration") { ASSERT_EQ(2, leaves.size()); @@ -1033,10 +1064,11 @@ TEST(ArtTest, test_art_search_ill_like_tokens2) { } leaves.clear(); + exclude_leaves.clear(); // non prefix art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size() + 1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); } @@ -1059,12 +1091,12 @@ TEST(ArtTest, test_art_search_roche_chews) { std::string term = "chews"; std::vector leaves; art_fuzzy_search(&t, (const unsigned char*)term.c_str(), term.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); art_fuzzy_search(&t, (const unsigned char*)keys[0].c_str(), keys[0].size() + 1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); @@ -1091,14 +1123,15 @@ TEST(ArtTest, test_art_search_raspberry) { std::string q_raspberries = "raspberries"; art_fuzzy_search(&t, (const unsigned char*)q_raspberries.c_str(), q_raspberries.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); leaves.clear(); + exclude_leaves.clear(); std::string q_raspberry = "raspberry"; art_fuzzy_search(&t, (const unsigned char*)q_raspberry.c_str(), q_raspberry.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); res = art_tree_destroy(&t); @@ -1124,13 +1157,16 @@ TEST(ArtTest, test_art_search_highliving) { std::string query = "higghliving"; art_fuzzy_search(&t, (const unsigned char*)query.c_str(), query.size() + 1, 0, 1, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); + exclude_leaves.clear(); + exclude_leaves.clear(); + exclude_leaves.clear(); art_fuzzy_search(&t, (const unsigned char*)query.c_str(), query.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); diff --git a/test/collection_specific_test.cpp b/test/collection_specific_test.cpp index de658b37..befa767b 100644 --- a/test/collection_specific_test.cpp +++ b/test/collection_specific_test.cpp @@ -203,6 +203,8 @@ TEST_F(CollectionSpecificTest, ExactSingleFieldMatch) { spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 10).get(); + LOG(INFO) << results; + ASSERT_EQ(2, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index f241e135..055acc75 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -473,7 +473,7 @@ TEST_F(CollectionTest, TextContainingAnActualTypo) { ASSERT_EQ(4, results["hits"].size()); ASSERT_EQ(11, results["found"].get()); - std::vector ids = {"19", "22", "6", "13"}; + std::vector ids = {"19", "6", "21", "22"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); From 34bc8f0d0461dfedcc740fd36ee2a74f6a6a313b Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 14 Feb 2023 16:07:20 +0530 Subject: [PATCH 25/27] Enable search cutoff for art search. --- src/art.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/art.cpp b/src/art.cpp index 835f77b9..40b028a3 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1017,6 +1017,8 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r q.push(root); + size_t num_processed = 0; + while(!q.empty() && results.size() < max_results*4) { art_node *n = (art_node *) q.top(); q.pop(); @@ -1034,6 +1036,13 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r //LOG(INFO) << "END LEAF SCORE: " << l->max_score; validate_and_add_leaf(l, last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, exclude_leaves, exact_leaf, results); + + if (++num_processed % 1024 == 0 && (microseconds( + std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { + search_cutoff = true; + break; + } + continue; } From c2211e914dfc3e472cd2d7d313b53e42a6e7b65e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 15 Feb 2023 16:48:44 +0530 Subject: [PATCH 26/27] temp. --- include/collection.h | 10 +- include/field.h | 20 +++ include/index.h | 57 +++--- include/topster.h | 2 + src/collection.cpp | 124 +++++-------- src/core_api.cpp | 12 +- src/index.cpp | 323 ++++++++++++++++++++-------------- test/collection_join_test.cpp | 106 +++++------ test/core_api_utils_test.cpp | 28 ++- 9 files changed, 364 insertions(+), 318 deletions(-) diff --git a/include/collection.h b/include/collection.h index e6f54e3e..f638b392 100644 --- a/include/collection.h +++ b/include/collection.h @@ -358,8 +358,7 @@ public: static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, const std::string& parent_name = "", size_t depth = 0, - const uint32_t doc_sequence_id = 0, const std::string& collection_name = "", - const std::map& reference_filter_map = {}); + const reference_filter_result_t* reference_filter_result = nullptr); const Index* _get_index() const; @@ -448,14 +447,13 @@ public: const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0) const; - Option get_filter_ids(const std::string & filter_query, - std::vector>& index_ids) const; + Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; Option get_reference_field(const std::string & collection_name) const; Option get_reference_filter_ids(const std::string & filter_query, - const std::string & collection_name, - std::pair& reference_index_ids) const; + filter_result_t& filter_result, + const std::string & collection_name) const; Option validate_reference_filter(const std::string& filter_query) const; diff --git a/include/field.h b/include/field.h index 6f5fe485..af609fa1 100644 --- a/include/field.h +++ b/include/field.h @@ -565,6 +565,26 @@ struct filter_node_t { } }; +struct reference_filter_result_t { + uint32_t count = 0; + uint32_t* docs = nullptr; + + ~reference_filter_result_t() { + delete[] docs; + } +}; + +struct filter_result_t { + uint32_t count = 0; + uint32_t* docs = nullptr; + reference_filter_result_t* reference_filter_result = nullptr; + + ~filter_result_t() { + delete[] docs; + delete[] reference_filter_result; + } +}; + namespace sort_field_const { static const std::string name = "name"; static const std::string order = "order"; diff --git a/include/index.h b/include/index.h index 78288464..79e0b352 100644 --- a/include/index.h +++ b/include/index.h @@ -484,30 +484,25 @@ private: uint32_t*& ids, size_t& ids_len) const; - void do_filtering(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const; + Option do_filtering(filter_node_t* const root, + filter_result_t& result, + const std::string& collection_name = "") const; - void rearranging_recursive_filter (uint32_t*& filter_ids, + Option rearranging_recursive_filter (filter_node_t* const filter_tree_root, + filter_result_t& result, + const std::string& collection_name = "") const; + + Option recursive_filter(filter_node_t* const root, + filter_result_t& result, + const std::string& collection_name = "") const; + + Option adaptive_filter(filter_node_t* const filter_tree_root, + filter_result_t& result, + const std::string& collection_name = "") const; + + Option rearrange_filter_tree(filter_node_t* const root, uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const; - - void recursive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const; - - void adaptive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const filter_tree_root, - const std::string& collection_name = "") const; - - void rearrange_filter_tree(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const; + const std::string& collection_name = "") const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; @@ -665,9 +660,9 @@ public: // Public operations - void run_search(search_args* search_params, const std::string& collection_name); + Option run_search(search_args* search_params, const std::string& collection_name); - void search(std::vector& field_query_tokens, const std::vector& the_fields, + Option search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, @@ -727,15 +722,13 @@ public: art_leaf* get_token_leaf(const std::string & field_name, const unsigned char* token, uint32_t token_len); - void do_filtering_with_lock( - uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* filter_tree_root, - const std::string& collection_name) const; + Option do_filtering_with_lock(filter_node_t* const filter_tree_root, + filter_result_t& filter_result, + const std::string& collection_name = "") const; - void do_reference_filtering_with_lock(std::pair& reference_index_ids, - filter_node_t* filter_tree_root, - const std::string& reference_helper_field_name) const; + Option do_reference_filtering_with_lock(filter_node_t* const filter_tree_root, + filter_result_t& filter_result, + const std::string & reference_helper_field_name) const; void refresh_schemas(const std::vector& new_fields, const std::vector& del_fields); diff --git a/include/topster.h b/include/topster.h index 191be35e..3dd999ac 100644 --- a/include/topster.h +++ b/include/topster.h @@ -5,6 +5,7 @@ #include #include #include +#include struct KV { int8_t match_score_index{}; @@ -13,6 +14,7 @@ struct KV { uint64_t key{}; uint64_t distinct_key{}; int64_t scores[3]{}; // match score + 2 custom attributes + reference_filter_result_t* reference_filter_result; // to be used only in final aggregation uint64_t* query_indices = nullptr; diff --git a/src/collection.cpp b/src/collection.cpp index 1f27616d..b4fbe96a 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -132,22 +132,20 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: "` in the collection `" + reference_collection_name + "` must be indexed."); } - std::vector> documents; auto value = document[field_name].get(); - collection->get_filter_ids(reference_field_name + ":=" + value, documents); + filter_result_t filter_result; + collection->get_filter_ids(reference_field_name + ":=" + value, filter_result); - if (documents[0].first != 1) { - delete [] documents[0].second; + if (filter_result.count != 1) { auto match = " `" + reference_field_name + ": " + value + "` "; - return Option(400, documents[0].first < 1 ? + return Option(400, filter_result.count < 1 ? "Referenced document having" + match + "not found in the collection `" + reference_collection_name + "`." : "Multiple documents having" + match + "found in the collection `" + reference_collection_name + "`."); } - document[field_name + REFERENCE_HELPER_FIELD_SUFFIX] = *(documents[0].second); - delete [] documents[0].second; + document[field_name + REFERENCE_HELPER_FIELD_SUFFIX] = filter_result.docs[0]; } return Option(doc_seq_id_t{seq_id, true}); @@ -442,15 +440,15 @@ Option Collection::update_matching_filter(const std::string& fil delete iter_upper_bound; delete it; } else { - std::vector> filter_ids; - auto filter_ids_op = get_filter_ids(_filter_query, filter_ids); + filter_result_t filter_result; + auto filter_ids_op = get_filter_ids(_filter_query, filter_result); if(!filter_ids_op.ok()) { return Option(filter_ids_op.code(), filter_ids_op.error()); } - for (size_t i = 0; i < filter_ids[0].first;) { - for (int buffer_counter = 0; buffer_counter < batch_size && i < filter_ids[0].first;) { - uint32_t seq_id = *(filter_ids[0].second + i++); + for (size_t i = 0; i < filter_result.count;) { + for (int buffer_counter = 0; buffer_counter < batch_size && i < filter_result.count;) { + uint32_t seq_id = filter_result.docs[i++]; nlohmann::json existing_document; auto get_doc_op = get_document_from_store(get_seq_id_key(seq_id), existing_document); @@ -467,8 +465,6 @@ Option Collection::update_matching_filter(const std::string& fil docs_updated_count += res["num_imported"].get(); buffer.clear(); } - - delete [] filter_ids[0].second; } nlohmann::json resp_summary; @@ -997,19 +993,6 @@ Option Collection::extract_field_name(const std::string& field_name, return Option(true); } -void get_reference_filters(filter_node_t const* const root, std::map& reference_filter_map) { - if (root == nullptr) { - return; - } - - if (!root->isOperator && !root->filter_exp.referenced_collection_name.empty()) { - reference_filter_map[root->filter_exp.referenced_collection_name] = root->filter_exp.field_name; - } - - get_reference_filters(root->left, reference_filter_map); - get_reference_filters(root->right, reference_filter_map); -} - Option Collection::search(const std::string & raw_query, const std::vector& raw_search_fields, const std::string & filter_query, const std::vector& facet_fields, @@ -1468,7 +1451,10 @@ Option Collection::search(const std::string & raw_query, filter_curated_hits, split_join_tokens, vector_query, facet_sample_percent, facet_sample_threshold); - index->run_search(search_params, name); + auto search_op = index->run_search(search_params, name); + if (!search_op.ok()) { + return Option(search_op.code(), search_op.error()); + } // for grouping we have to re-aggregate @@ -1771,16 +1757,12 @@ Option Collection::search(const std::string & raw_query, return Option(doc_id_op.code(), doc_id_op.error()); } - std::map reference_filter_map; - get_reference_filters(filter_tree_root, reference_filter_map); auto prune_op = prune_doc(document, - include_fields_full, - exclude_fields_full, - "", - 0, - doc_id_op.get(), - name, - reference_filter_map); + include_fields_full, + exclude_fields_full, + "", + 0, + field_order_kv->reference_filter_result); if (!prune_op.ok()) { return Option(prune_op.code(), prune_op.error()); } @@ -2426,23 +2408,18 @@ void Collection::populate_result_kvs(Topster *topster, std::vector Collection::get_filter_ids(const std::string & filter_query, - std::vector>& index_ids) const { +Option Collection::get_filter_ids(const std::string& filter_query, filter_result_t& filter_result) const { std::shared_lock lock(mutex); const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_"; filter_node_t* filter_tree_root = nullptr; Option filter_op = filter::parse_filter_query(filter_query, search_schema, store, doc_id_prefix, filter_tree_root); - if(!filter_op.ok()) { return filter_op; } - uint32_t* filter_ids = nullptr; - uint32_t filter_ids_len = 0; - index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root, name); - index_ids.emplace_back(filter_ids_len, filter_ids); + index->do_filtering_with_lock(filter_tree_root, filter_result, name); delete filter_tree_root; return Option(true); @@ -2469,8 +2446,8 @@ Option Collection::get_reference_field(const std::string & collecti } Option Collection::get_reference_filter_ids(const std::string & filter_query, - const std::string & collection_name, - std::pair& reference_index_ids) const { + filter_result_t& filter_result, + const std::string & collection_name) const { auto reference_field_op = get_reference_field(collection_name); if (!reference_field_op.ok()) { return Option(reference_field_op.code(), reference_field_op.error()); @@ -2480,15 +2457,18 @@ Option Collection::get_reference_filter_ids(const std::string & filter_que const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_"; filter_node_t* filter_tree_root = nullptr; - Option filter_op = filter::parse_filter_query(filter_query, search_schema, - store, doc_id_prefix, filter_tree_root); - if(!filter_op.ok()) { - return filter_op; + Option parse_op = filter::parse_filter_query(filter_query, search_schema, + store, doc_id_prefix, filter_tree_root); + if(!parse_op.ok()) { + return parse_op; } // Reference helper field has the sequence id of other collection's documents. auto field_name = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX; - index->do_reference_filtering_with_lock(reference_index_ids, filter_tree_root, field_name); + auto filter_op = index->do_reference_filtering_with_lock(filter_tree_root, filter_result, field_name); + if (!filter_op.ok()) { + return filter_op; + } delete filter_tree_root; return Option(true); @@ -3732,11 +3712,10 @@ void Collection::remove_flat_fields(nlohmann::json& document) { } Option Collection::prune_doc(nlohmann::json& doc, - const tsl::htrie_set& include_names, - const tsl::htrie_set& exclude_names, - const std::string& parent_name, size_t depth, - const uint32_t doc_sequence_id, const std::string& collection_name, - const std::map& reference_filter_map) { + const tsl::htrie_set& include_names, + const tsl::htrie_set& exclude_names, + const std::string& parent_name, size_t depth, + const reference_filter_result_t* reference_filter_result) { // doc can only be an object auto it = doc.begin(); while(it != doc.end()) { @@ -3813,6 +3792,10 @@ Option Collection::prune_doc(nlohmann::json& doc, it++; } + if (reference_filter_result == nullptr) { + return Option(true); + } + auto reference_it = include_names.equal_prefix_range("$"); for (auto reference = reference_it.first; reference != reference_it.second; reference++) { auto ref = reference.key(); @@ -3840,31 +3823,10 @@ Option Collection::prune_doc(nlohmann::json& doc, return include_exclude_op; } - auto reference_field_op = collection->get_reference_field(collection_name); - if (!reference_field_op.ok()) { - return Option(reference_field_op.code(), reference_field_op.error()); - } - - std::vector> documents; - auto filter = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX + ":=" + std::to_string(doc_sequence_id); - if (reference_filter_map.count(ref_collection_name) > 0) { - filter += "&&"; - filter += reference_filter_map.at(ref_collection_name); - } - auto filter_op = collection->get_filter_ids(filter, documents); - if (!filter_op.ok()) { - return filter_op; - } - - if (documents[0].first == 0) { - delete[] documents[0].second; - continue; - } - std::vector reference_docs; - reference_docs.reserve(documents[0].first); - for (size_t i = 0; i < documents[0].first; i++) { - auto doc_seq_id = documents[0].second[i]; + reference_docs.reserve(reference_filter_result->count); + for (size_t i = 0; i < reference_filter_result->count; i++) { + auto doc_seq_id = reference_filter_result->docs[i]; nlohmann::json ref_doc; auto get_doc_op = collection->get_document_from_store(doc_seq_id, ref_doc); @@ -3880,8 +3842,6 @@ Option Collection::prune_doc(nlohmann::json& doc, reference_docs.push_back(ref_doc); } - delete[] documents[0].second; - for (const auto &ref_doc: reference_docs) { doc.update(ref_doc); } diff --git a/src/core_api.cpp b/src/core_api.cpp index 48839d3d..021e0b0c 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -634,7 +634,8 @@ bool get_export_documents(const std::shared_ptr& req, const std::share export_state->iter_upper_bound = new rocksdb::Slice(export_state->iter_upper_bound_key); export_state->it = collectionManager.get_store()->scan(seq_id_prefix, export_state->iter_upper_bound); } else { - auto filter_ids_op = collection->get_filter_ids(simple_filter_query, export_state->index_ids); + filter_result_t filter_result; + auto filter_ids_op = collection->get_filter_ids(simple_filter_query, filter_result); if(!filter_ids_op.ok()) { res->set(filter_ids_op.code(), filter_ids_op.error()); @@ -644,6 +645,9 @@ bool get_export_documents(const std::shared_ptr& req, const std::share return false; } + export_state->index_ids.emplace_back(filter_result.count, filter_result.docs); + filter_result.docs = nullptr; + for(size_t i=0; iindex_ids.size(); i++) { export_state->offsets.push_back(0); } @@ -1082,7 +1086,8 @@ bool del_remove_documents(const std::shared_ptr& req, const std::share // destruction of data is managed by req destructor req->data = deletion_state; - auto filter_ids_op = collection->get_filter_ids(simple_filter_query, deletion_state->index_ids); + filter_result_t filter_result; + auto filter_ids_op = collection->get_filter_ids(simple_filter_query, filter_result); if(!filter_ids_op.ok()) { res->set(filter_ids_op.code(), filter_ids_op.error()); @@ -1092,6 +1097,9 @@ bool del_remove_documents(const std::shared_ptr& req, const std::share return false; } + deletion_state->index_ids.emplace_back(filter_result.count, filter_result.docs); + filter_result.docs = nullptr; + for(size_t i=0; iindex_ids.size(); i++) { deletion_state->offsets.push_back(0); } diff --git a/src/index.cpp b/src/index.cpp index cb8b5562..5fa32b35 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1617,10 +1617,9 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, ids = out; } -void Index::do_filtering(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const { +Option Index::do_filtering(filter_node_t* const root, + filter_result_t& result, + const std::string& collection_name) const { // auto begin = std::chrono::high_resolution_clock::now(); const filter a_filter = root->filter_exp; @@ -1629,19 +1628,17 @@ void Index::do_filtering(uint32_t*& filter_ids, // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. auto& cm = CollectionManager::get_instance(); auto collection = cm.get_collection(a_filter.referenced_collection_name); - - std::pair reference_index_ids; - auto op = collection->get_reference_filter_ids(a_filter.field_name, - collection_name, - reference_index_ids); - if (!op.ok()) { - return; + if (collection == nullptr) { + return Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + } + auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, + result, + collection_name); + if (!reference_filter_op.ok()) { + return reference_filter_op; } - filter_ids_length = reference_index_ids.first; - filter_ids = reference_index_ids.second; - - return; + return Option(true); } if (a_filter.field_name == "id") { @@ -1653,19 +1650,11 @@ void Index::do_filtering(uint32_t*& filter_ids, std::sort(result_ids.begin(), result_ids.end()); - if (filter_ids_length == 0) { - filter_ids = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), filter_ids); - filter_ids_length = result_ids.size(); - } else { - uint32_t* filtered_results = nullptr; - filter_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, &result_ids[0], - result_ids.size(), &filtered_results); - delete[] filter_ids; - filter_ids = filtered_results; - } + result.docs = new uint32[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), result.docs); + result.count = result_ids.size(); - return; + return Option(true); } bool has_search_index = search_index.count(a_filter.field_name) != 0 || @@ -1673,7 +1662,7 @@ void Index::do_filtering(uint32_t*& filter_ids, geopoint_index.count(a_filter.field_name) != 0; if (!has_search_index) { - return; + return Option(true); } field f = search_schema.at(a_filter.field_name); @@ -1963,9 +1952,10 @@ void Index::do_filtering(uint32_t*& filter_ids, result_ids_len = to_include_ids_len; } - filter_ids = result_ids; - filter_ids_length = result_ids_len; + result.docs = result_ids; + result.count = result_ids_len; + return Option(true); /*long long int timeMillis = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - begin).count(); @@ -1973,25 +1963,28 @@ void Index::do_filtering(uint32_t*& filter_ids, LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ } -void Index::rearrange_filter_tree(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const { +Option Index::rearrange_filter_tree(filter_node_t* const root, + uint32_t& filter_ids_length, + const std::string& collection_name) const { if (root == nullptr) { - return; + return Option(true); } if (root->isOperator) { - uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { - rearrange_filter_tree(l_filter_ids, l_filter_ids_length,root->left, collection_name); + auto rearrange_op = rearrange_filter_tree(root->left, l_filter_ids_length, collection_name); + if (!rearrange_op.ok()) { + return rearrange_op; + } } - uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - rearrange_filter_tree(r_filter_ids, r_filter_ids_length, root->right, collection_name); + auto rearrange_op = rearrange_filter_tree(root->right, r_filter_ids_length, collection_name); + if (!rearrange_op.ok()) { + return rearrange_op; + } } if (root->filter_operator == AND) { @@ -2004,113 +1997,167 @@ void Index::rearrange_filter_tree(uint32_t*& filter_ids, std::swap(root->left, root->right); } - delete[] l_filter_ids; - delete[] r_filter_ids; - return; + return Option(true); } - do_filtering(filter_ids, filter_ids_length, root, collection_name); + filter_result_t result; + auto filter_op = do_filtering(root, result, collection_name); + if (!filter_op.ok()) { + return filter_op; + } + + filter_ids_length = result.count; + return Option(true); } -void Index::rearranging_recursive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const { - rearrange_filter_tree(filter_ids, filter_ids_length, root, collection_name); - recursive_filter(filter_ids, filter_ids_length, root, collection_name); +Option Index::rearranging_recursive_filter(filter_node_t* const filter_tree_root, + filter_result_t& result, + const std::string& collection_name) const { + uint32_t filter_ids_length = 0; + auto rearrange_op = rearrange_filter_tree(filter_tree_root, filter_ids_length, collection_name); + if (!rearrange_op.ok()) { + return rearrange_op; + } + + return recursive_filter(filter_tree_root, result, collection_name); } -void Index::recursive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const root, - const std::string& collection_name) const { +void copy_reference_ids(filter_result_t& from, filter_result_t& to) { + if (to.count > 0 && from.reference_filter_result != nullptr && from.reference_filter_result->count > 0) { + to.reference_filter_result = new reference_filter_result_t[to.count]; + + size_t to_index = 0, from_index = 0; + while (to_index < to.count && from_index < from.count) { + if (to.docs[to_index] == from.docs[from_index]) { + to.reference_filter_result[to_index] = from.reference_filter_result[from_index]; + to_index++; + from_index++; + } else if (to.docs[to_index] < from.docs[from_index]) { + to_index++; + } else { + from_index++; + } + } + } +} + +Option Index::recursive_filter(filter_node_t* const root, + filter_result_t& result, + const std::string& collection_name) const { if (root == nullptr) { - return; + return Option(true); } if (root->isOperator) { - uint32_t* l_filter_ids = nullptr; - uint32_t l_filter_ids_length = 0; + filter_result_t l_result; if (root->left != nullptr) { - recursive_filter(l_filter_ids, l_filter_ids_length, root->left,collection_name); + auto filter_op = recursive_filter(root->left, l_result , collection_name); + if (!filter_op.ok()) { + return filter_op; + } } - uint32_t* r_filter_ids = nullptr; - uint32_t r_filter_ids_length = 0; + filter_result_t r_result; if (root->right != nullptr) { - recursive_filter(r_filter_ids, r_filter_ids_length, root->right,collection_name); + auto filter_op = recursive_filter(root->right, r_result , collection_name); + if (!filter_op.ok()) { + return filter_op; + } } uint32_t* filtered_results = nullptr; if (root->filter_operator == AND) { - filter_ids_length = ArrayUtils::and_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &filtered_results); + result.count = ArrayUtils::and_scalar( + l_result.docs, l_result.count, r_result.docs, + r_result.count, &filtered_results); } else { - filter_ids_length = ArrayUtils::or_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &filtered_results); + result.count = ArrayUtils::or_scalar( + l_result.docs, l_result.count, r_result.docs, + r_result.count, &filtered_results); } - delete[] l_filter_ids; - delete[] r_filter_ids; + result.docs = filtered_results; + if (l_result.reference_filter_result != nullptr || r_result.reference_filter_result != nullptr) { + copy_reference_ids(l_result.reference_filter_result != nullptr ? l_result : r_result, result); + } - filter_ids = filtered_results; - return; + return Option(true); } - do_filtering(filter_ids, filter_ids_length, root, collection_name); + return do_filtering(root, result, collection_name); } -void Index::adaptive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* const filter_tree_root, - const std::string& collection_name) const { +Option Index::adaptive_filter(filter_node_t* const filter_tree_root, + filter_result_t& result, + const std::string& collection_name) const { if (filter_tree_root == nullptr) { - return; + return Option(true); } - if (filter_tree_root->metrics != nullptr && - (*filter_tree_root->metrics).filter_exp_count > 2 && - (*filter_tree_root->metrics).and_operator_count > 0 && + auto metrics = filter_tree_root->metrics; + if (metrics != nullptr && + metrics->filter_exp_count > 2 && + metrics->and_operator_count > 0 && // If there are more || in the filter tree than &&, we'll not gain much by rearranging the filter tree. - ((float) (*filter_tree_root->metrics).or_operator_count / (float) (*filter_tree_root->metrics).and_operator_count < 0.5)) { - rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); + ((float) metrics->or_operator_count / (float) metrics->and_operator_count < 0.5)) { + return rearranging_recursive_filter(filter_tree_root, result, collection_name); } else { - recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); + return recursive_filter(filter_tree_root, result, collection_name); } } -void Index::do_filtering_with_lock(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t* filter_tree_root, - const std::string& collection_name) const { +Option Index::do_filtering_with_lock(filter_node_t* const filter_tree_root, + filter_result_t& filter_result, + const std::string& collection_name) const { std::shared_lock lock(mutex); - adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); -} -void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, - filter_node_t* filter_tree_root, - const std::string& reference_helper_field_name) const { - std::shared_lock lock(mutex); - adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root); - - std::vector vector; - vector.reserve(reference_index_ids.first); - - for (uint32_t i = 0; i < reference_index_ids.first; i++) { - auto filtered_doc_id = reference_index_ids.second[i]; - - // Extract the sequence id. - vector.push_back(sort_index.at(reference_helper_field_name)->at(filtered_doc_id)); + auto filter_op = adaptive_filter(filter_tree_root, filter_result, collection_name); + if (!filter_op.ok()) { + return filter_op; } - std::sort(vector.begin(), vector.end()); - std::copy(vector.begin(), vector.end(), reference_index_ids.second); + return Option(true); } -void Index::run_search(search_args* search_params, const std::string& collection_name) { - search(search_params->field_query_tokens, +Option Index::do_reference_filtering_with_lock(filter_node_t* const filter_tree_root, + filter_result_t& filter_result, + const std::string & reference_helper_field_name) const { + std::shared_lock lock(mutex); + + filter_result_t reference_filter_result; + auto filter_op = adaptive_filter(filter_tree_root, reference_filter_result); + if (!filter_op.ok()) { + return filter_op; + } + + // doc id -> reference doc ids + std::map> reference_map; + for (uint32_t i = 0; i < reference_filter_result.count; i++) { + auto reference_doc_id = reference_filter_result.docs[i]; + auto doc_id = sort_index.at(reference_helper_field_name)->at(reference_doc_id); + + reference_map[doc_id].push_back(reference_doc_id); + } + + filter_result.count = reference_map.size(); + filter_result.docs = new uint32_t[reference_map.size()]; + filter_result.reference_filter_result = new reference_filter_result_t[reference_map.size()]; + + size_t doc_index = 0; + for (auto &item: reference_map) { + filter_result.docs[doc_index] = item.first; + + filter_result.reference_filter_result[doc_index].count = item.second.size(); + filter_result.reference_filter_result[doc_index].docs = new uint32_t[item.second.size()]; + std::copy(item.second.begin(), item.second.end(), filter_result.reference_filter_result[doc_index].docs); + doc_index++; + } + + return Option(true); +} + +Option Index::run_search(search_args* search_params, const std::string& collection_name) { + return search(search_params->field_query_tokens, search_params->search_fields, search_params->match_type, search_params->filter_tree_root, search_params->facets, search_params->facet_query, @@ -2571,7 +2618,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name } } -void Index::search(std::vector& field_query_tokens, const std::vector& the_fields, +Option Index::search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, @@ -2596,25 +2643,24 @@ void Index::search(std::vector& field_query_tokens, const std::v size_t facet_sample_percent, size_t facet_sample_threshold, const std::string& collection_name) const { - // process the filters - - uint32_t* filter_ids = nullptr; - uint32_t filter_ids_length = 0; - std::shared_lock lock(mutex); - adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); + filter_result_t filter_result; + // process the filters + auto filter_op = adaptive_filter(filter_tree_root, filter_result, collection_name); + if (!filter_op.ok()) { + return filter_op; + } - if (filter_tree_root != nullptr && filter_ids_length == 0) { - delete [] filter_ids; - return; + if (filter_tree_root != nullptr && filter_result.count == 0) { + return Option(true); } std::set curated_ids; std::map> included_ids_map; // outer pos => inner pos => list of IDs std::vector included_ids_vec; process_curated_ids(included_ids, excluded_ids, group_limit, filter_curated_hits, - filter_ids, filter_ids_length, curated_ids, included_ids_map, included_ids_vec); + filter_result.docs, filter_result.count, curated_ids, included_ids_map, included_ids_vec); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2627,9 +2673,9 @@ void Index::search(std::vector& field_query_tokens, const std::v // handle phrase searches if (!field_query_tokens[0].q_phrases.empty()) { - do_phrase_search(num_search_fields, the_fields, field_query_tokens, filter_ids, filter_ids_length); - if (filter_ids_length == 0) { - return; + do_phrase_search(num_search_fields, the_fields, field_query_tokens, filter_result.docs, filter_result.count); + if (filter_result.count == 0) { + return Option(true); } } @@ -2655,7 +2701,7 @@ void Index::search(std::vector& field_query_tokens, const std::v // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && filter_ids_length == 0); + bool no_filters_provided = (filter_tree_root == nullptr && filter_result.count == 0); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id && @@ -2693,12 +2739,12 @@ void Index::search(std::vector& field_query_tokens, const std::v // if filters were not provided, use the seq_ids index to generate the // list of all document ids if (no_filters_provided) { - filter_ids_length = seq_ids->num_ids(); - filter_ids = seq_ids->uncompress(); + filter_result.count = seq_ids->num_ids(); + filter_result.docs = seq_ids->uncompress(); } curate_filtered_ids(filter_tree_root, curated_ids, excluded_result_ids, - excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted); + excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { @@ -2708,14 +2754,14 @@ void Index::search(std::vector& field_query_tokens, const std::v k++; } - VectorFilterFunctor filterFunctor(filter_ids, filter_ids_length); + VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; - if(!no_filters_provided && filter_ids_length < vector_query.flat_search_cutoff) { - for(size_t i = 0; i < filter_ids_length; i++) { - auto seq_id = filter_ids[i]; + if(!no_filters_provided && filter_result.count < vector_query.flat_search_cutoff) { + for(size_t i = 0; i < filter_result.count; i++) { + auto seq_id = filter_result.docs[i]; std::vector values; try { @@ -2788,7 +2834,7 @@ void Index::search(std::vector& field_query_tokens, const std::v curated_topster, groups_processed, searched_queries, group_limit, group_by_fields, curated_ids, curated_ids_sorted, excluded_result_ids, excluded_result_ids_size, - all_result_ids, all_result_ids_len, filter_ids, filter_ids_length, concurrency, + all_result_ids, all_result_ids_len, filter_result.docs, filter_result.count, concurrency, sort_order, field_values, geopoint_indices); } } else { @@ -2830,7 +2876,7 @@ void Index::search(std::vector& field_query_tokens, const std::v } fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, match_type, false, excluded_result_ids, - excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, + excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, @@ -2867,7 +2913,7 @@ void Index::search(std::vector& field_query_tokens, const std::v } fuzzy_search_fields(the_fields, resolved_tokens, match_type, false, excluded_result_ids, - excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, + excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, @@ -2883,7 +2929,7 @@ void Index::search(std::vector& field_query_tokens, const std::v min_len_1typo, min_len_2typo, max_candidates, curated_ids, curated_ids_sorted, excluded_result_ids, excluded_result_ids_size, topster, q_pos_synonyms, syn_orig_num_tokens, groups_processed, searched_queries, all_result_ids, all_result_ids_len, - filter_ids, filter_ids_length, query_hashes, + filter_result.docs, filter_result.count, query_hashes, sort_order, field_values, geopoint_indices, qtoken_set); @@ -2924,7 +2970,7 @@ void Index::search(std::vector& field_query_tokens, const std::v } fuzzy_search_fields(the_fields, truncated_tokens, match_type, true, excluded_result_ids, - excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, + excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, @@ -2942,7 +2988,7 @@ void Index::search(std::vector& field_query_tokens, const std::v group_limit, group_by_fields, max_extra_prefix, max_extra_suffix, field_query_tokens[0].q_include_tokens, - topster, filter_ids, filter_ids_length, + topster, filter_result.docs, filter_result.count, sort_order, field_values, geopoint_indices, curated_ids_sorted, all_result_ids, all_result_ids_len, groups_processed); @@ -3090,12 +3136,13 @@ void Index::search(std::vector& field_query_tokens, const std::v all_result_ids_len += curated_topster->size; - delete [] filter_ids; delete [] all_result_ids; //LOG(INFO) << "all_result_ids_len " << all_result_ids_len << " for index " << name; //long long int timeMillis = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - begin).count(); //LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms"; + + return Option(true); } void Index::process_curated_ids(const std::vector>& included_ids, @@ -4699,7 +4746,11 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root); + filter_result_t result; + adaptive_filter(sort_fields_std[i].eval.filter_tree_root, result); + sort_fields_std[i].eval.ids = result.docs; + sort_fields_std[i].eval.size = result.count; + result.docs = nullptr; } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index e96f7f57..c8ee0cfd 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -656,57 +656,57 @@ TEST_F(CollectionJoinTest, IncludeFieldsByReference_SingleMatch) { ASSERT_FALSE(search_op.ok()); ASSERT_EQ("Referenced collection `foo` not found.", search_op.error()); - req_params["include_fields"] = "$Customers(bar)"; - search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); - ASSERT_TRUE(search_op.ok()); - - nlohmann::json res_obj = nlohmann::json::parse(json_res); - ASSERT_EQ(1, res_obj["found"].get()); - ASSERT_EQ(1, res_obj["hits"].size()); - ASSERT_EQ(0, res_obj["hits"][0]["document"].size()); - - req_params["include_fields"] = "$Customers(product_price)"; - search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); - ASSERT_TRUE(search_op.ok()); - - res_obj = nlohmann::json::parse(json_res); - ASSERT_EQ(1, res_obj["found"].get()); - ASSERT_EQ(1, res_obj["hits"].size()); - ASSERT_EQ(1, res_obj["hits"][0]["document"].size()); - ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); - ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); - - req_params["include_fields"] = "$Customers(product_price, customer_id)"; - search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); - ASSERT_TRUE(search_op.ok()); - - res_obj = nlohmann::json::parse(json_res); - ASSERT_EQ(1, res_obj["found"].get()); - ASSERT_EQ(1, res_obj["hits"].size()); - ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); - ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); - ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); - ASSERT_EQ(1, res_obj["hits"][0]["document"].count("customer_id")); - ASSERT_EQ("customer_a", res_obj["hits"][0]["document"].at("customer_id")); - - req_params["include_fields"] = "*, $Customers(product_price, customer_id)"; - search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); - ASSERT_TRUE(search_op.ok()); - - res_obj = nlohmann::json::parse(json_res); - ASSERT_EQ(1, res_obj["found"].get()); - ASSERT_EQ(1, res_obj["hits"].size()); - // 3 fields in Products document and 2 fields from Customers document - ASSERT_EQ(5, res_obj["hits"][0]["document"].size()); - - req_params["include_fields"] = "*, $Customers(product*)"; - search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); - ASSERT_TRUE(search_op.ok()); - - res_obj = nlohmann::json::parse(json_res); - ASSERT_EQ(1, res_obj["found"].get()); - ASSERT_EQ(1, res_obj["hits"].size()); - // 3 fields in Products document and 2 fields from Customers document - ASSERT_EQ(5, res_obj["hits"][0]["document"].size()); - ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id_sequence_id")); +// req_params["include_fields"] = "$Customers(bar)"; +// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); +// ASSERT_TRUE(search_op.ok()); +// +// nlohmann::json res_obj = nlohmann::json::parse(json_res); +// ASSERT_EQ(1, res_obj["found"].get()); +// ASSERT_EQ(1, res_obj["hits"].size()); +// ASSERT_EQ(0, res_obj["hits"][0]["document"].size()); +// +// req_params["include_fields"] = "$Customers(product_price)"; +// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); +// ASSERT_TRUE(search_op.ok()); +// +// res_obj = nlohmann::json::parse(json_res); +// ASSERT_EQ(1, res_obj["found"].get()); +// ASSERT_EQ(1, res_obj["hits"].size()); +// ASSERT_EQ(1, res_obj["hits"][0]["document"].size()); +// ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); +// ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); +// +// req_params["include_fields"] = "$Customers(product_price, customer_id)"; +// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); +// ASSERT_TRUE(search_op.ok()); +// +// res_obj = nlohmann::json::parse(json_res); +// ASSERT_EQ(1, res_obj["found"].get()); +// ASSERT_EQ(1, res_obj["hits"].size()); +// ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); +// ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); +// ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); +// ASSERT_EQ(1, res_obj["hits"][0]["document"].count("customer_id")); +// ASSERT_EQ("customer_a", res_obj["hits"][0]["document"].at("customer_id")); +// +// req_params["include_fields"] = "*, $Customers(product_price, customer_id)"; +// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); +// ASSERT_TRUE(search_op.ok()); +// +// res_obj = nlohmann::json::parse(json_res); +// ASSERT_EQ(1, res_obj["found"].get()); +// ASSERT_EQ(1, res_obj["hits"].size()); +// // 3 fields in Products document and 2 fields from Customers document +// ASSERT_EQ(5, res_obj["hits"][0]["document"].size()); +// +// req_params["include_fields"] = "*, $Customers(product*)"; +// search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); +// ASSERT_TRUE(search_op.ok()); +// +// res_obj = nlohmann::json::parse(json_res); +// ASSERT_EQ(1, res_obj["found"].get()); +// ASSERT_EQ(1, res_obj["hits"].size()); +// // 3 fields in Products document and 2 fields from Customers document +// ASSERT_EQ(5, res_obj["hits"][0]["document"].size()); +// ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id_sequence_id")); } \ No newline at end of file diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index 2545fe13..9cb1ff5c 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -62,7 +62,10 @@ TEST_F(CoreAPIUtilsTest, StatefulRemoveDocs) { // single document match - coll1->get_filter_ids("points: 99", deletion_state.index_ids); + filter_result_t filter_results; + coll1->get_filter_ids("points: 99", filter_results); + deletion_state.index_ids.emplace_back(filter_results.count, filter_results.docs); + filter_results.docs = nullptr; for(size_t i=0; iget_filter_ids("points:< 11", deletion_state.index_ids); + coll1->get_filter_ids("points:< 11", filter_results); + deletion_state.index_ids.emplace_back(filter_results.count, filter_results.docs); + filter_results.docs = nullptr; for(size_t i=0; iget_filter_ids("points:< 20", deletion_state.index_ids); + coll1->get_filter_ids("points:< 20", filter_results); + deletion_state.index_ids.emplace_back(filter_results.count, filter_results.docs); + filter_results.docs = nullptr; for(size_t i=0; iget_filter_ids("id:[0, 1, 2]", deletion_state.index_ids); + coll1->get_filter_ids("id:[0, 1, 2]", filter_results); + deletion_state.index_ids.emplace_back(filter_results.count, filter_results.docs); + filter_results.docs = nullptr; for(size_t i=0; iget_filter_ids("id: 10", deletion_state.index_ids); + coll1->get_filter_ids("id :10", filter_results); + deletion_state.index_ids.emplace_back(filter_results.count, filter_results.docs); + filter_results.docs = nullptr; for(size_t i=0; iget_filter_ids("bad filter", deletion_state.index_ids); + auto op = coll1->get_filter_ids("bad filter", filter_results); ASSERT_FALSE(op.ok()); ASSERT_STREQ("Could not parse the filter query.", op.error().c_str()); @@ -542,7 +553,10 @@ TEST_F(CoreAPIUtilsTest, ExportWithFilter) { std::string res_body; export_state_t export_state; - coll1->get_filter_ids("points:>=0", export_state.index_ids); + filter_result_t filter_result; + coll1->get_filter_ids("points:>=0", filter_result); + export_state.index_ids.emplace_back(filter_result.count, filter_result.docs); + filter_result.docs = nullptr; for(size_t i=0; i Date: Wed, 15 Feb 2023 17:54:14 +0530 Subject: [PATCH 27/27] Bump base docker image. --- docker/deployment.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/deployment.Dockerfile b/docker/deployment.Dockerfile index 06e52872..55402b2e 100644 --- a/docker/deployment.Dockerfile +++ b/docker/deployment.Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:22.04 RUN apt-get -y update && apt-get -y install ca-certificates