From 0d5eef664e2ef15e1a14efdb53b007f552f3276a Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 11 Jun 2021 20:29:21 +0530 Subject: [PATCH] Ability to disable overrides. --- include/collection.h | 14 ++++---- src/collection.cpp | 58 +++++++++++++++++-------------- src/collection_manager.cpp | 10 +++++- test/collection_override_test.cpp | 26 ++++++++++++++ 4 files changed, 74 insertions(+), 34 deletions(-) diff --git a/include/collection.h b/include/collection.h index 7be7842c..2033c9b3 100644 --- a/include/collection.h +++ b/include/collection.h @@ -346,11 +346,12 @@ private: void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); - void populate_overrides(std::string query, - const std::map>& pinned_hits, - const std::vector& hidden_hits, - std::map>& include_ids, - std::vector & excluded_ids) const; + void curate_results(std::string query, + const std::map>& pinned_hits, + const std::vector& hidden_hits, + std::map>& include_ids, + std::vector & excluded_ids, + bool enable_overrides) const; Option check_and_update_schema(nlohmann::json& document, const DIRTY_VALUES& dirty_values); @@ -532,7 +533,8 @@ public: std::vector query_by_weights={}, size_t limit_hits=UINT32_MAX, bool prioritize_exact_match=true, - bool pre_segmented_query=false) const; + bool pre_segmented_query=false, + bool enable_overrides=true) const; Option get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids); diff --git a/src/collection.cpp b/src/collection.cpp index 1ac59f11..da1c59da 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -414,11 +414,12 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash } } -void Collection::populate_overrides(std::string query, - const std::map>& pinned_hits, - const std::vector& hidden_hits, - std::map>& include_ids, - std::vector & excluded_ids) const { +void Collection::curate_results(std::string query, + const std::map>& pinned_hits, + const std::vector& hidden_hits, + std::map>& include_ids, + std::vector & excluded_ids, + bool enable_overrides) const { StringUtils::tolowercase(query); std::set excluded_set; @@ -435,30 +436,32 @@ void Collection::populate_overrides(std::string query, } } - for(const auto & override_kv: overrides) { - const auto & override = override_kv.second; + if(enable_overrides) { + for(const auto & override_kv: overrides) { + const auto & override = override_kv.second; - if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) || - (override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) { + if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) || + (override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) { - // have to ensure that dropped hits take precedence over added hits - for(const auto & hit: override.drop_hits) { - Option seq_id_op = doc_id_to_seq_id(hit.doc_id); - if(seq_id_op.ok()) { - excluded_ids.push_back(seq_id_op.get()); - excluded_set.insert(seq_id_op.get()); + // have to ensure that dropped hits take precedence over added hits + for(const auto & hit: override.drop_hits) { + Option seq_id_op = doc_id_to_seq_id(hit.doc_id); + if(seq_id_op.ok()) { + excluded_ids.push_back(seq_id_op.get()); + excluded_set.insert(seq_id_op.get()); + } } - } - for(const auto & hit: override.add_hits) { - Option seq_id_op = doc_id_to_seq_id(hit.doc_id); - if(!seq_id_op.ok()) { - continue; - } - uint32_t seq_id = seq_id_op.get(); - bool excluded = (excluded_set.count(seq_id) != 0); - if(!excluded) { - include_ids[hit.position].push_back(seq_id); + for(const auto & hit: override.add_hits) { + Option seq_id_op = doc_id_to_seq_id(hit.doc_id); + if(!seq_id_op.ok()) { + continue; + } + uint32_t seq_id = seq_id_op.get(); + bool excluded = (excluded_set.count(seq_id) != 0); + if(!excluded) { + include_ids[hit.position].push_back(seq_id); + } } } } @@ -505,7 +508,8 @@ Option Collection::search(const std::string & query, const std:: std::vector query_by_weights, size_t limit_hits, bool prioritize_exact_match, - bool pre_segmented_query) const { + bool pre_segmented_query, + bool enable_overrides) const { std::shared_lock lock(mutex); @@ -550,7 +554,7 @@ Option Collection::search(const std::string & query, const std:: std::vector hidden_hits; StringUtils::split(hidden_hits_str, hidden_hits, ","); - populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids); + curate_results(query, pinned_hits, hidden_hits, include_ids, excluded_ids, enable_overrides); /*for(auto& kv: include_ids) { LOG(INFO) << "key: " << kv.first; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 4d328534..c504cd72 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -485,6 +485,7 @@ Option CollectionManager::do_search(std::map& re const char *PINNED_HITS = "pinned_hits"; const char *HIDDEN_HITS = "hidden_hits"; + const char *ENABLE_OVERRIDES = "enable_overrides"; // strings under this length will be fully highlighted, instead of showing a snippet of relevant portion const char *SNIPPET_THRESHOLD = "snippet_threshold"; @@ -705,6 +706,12 @@ Option CollectionManager::do_search(std::map& re req_params[HIDDEN_HITS] = ""; } + if(req_params.count(ENABLE_OVERRIDES) == 0) { + req_params[ENABLE_OVERRIDES] = "true"; + } + + bool enable_overrides = (req_params[ENABLE_OVERRIDES] == "true"); + CollectionManager & collectionManager = CollectionManager::get_instance(); auto collection = collectionManager.get_collection(req_params["collection"]); @@ -760,7 +767,8 @@ Option CollectionManager::do_search(std::map& re query_by_weights, static_cast(std::stol(req_params[LIMIT_HITS])), prioritize_exact_match, - pre_segmented_query + pre_segmented_query, + enable_overrides ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index fc40fef4..ca33b70d 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -170,6 +170,32 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeExactQueryMatch) { ASSERT_STREQ("2", results["hits"][2]["document"]["id"].get().c_str()); ASSERT_STREQ("1", results["hits"][3]["document"]["id"].get().c_str()); + // ability to disable overrides + bool enable_overrides = false; + res_op = coll_mul_fields->search("will", {"title"}, "", {}, {}, {0}, 10, + 1, FREQUENCY, {false}, 0, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 0, {}, {}, {}, 0, + "", "", {1}, 10000, true, false, enable_overrides); + ASSERT_TRUE(res_op.ok()); + results = res_op.get(); + + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ(2, results["found"].get()); + + ASSERT_STREQ("3", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get().c_str()); + + enable_overrides = true; + res_op = coll_mul_fields->search("will", {"title"}, "", {}, {}, {0}, 10, + 1, FREQUENCY, {false}, 0, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 0, {}, {}, {}, 0, + "", "", {1}, 10000, true, false, enable_overrides); + ASSERT_TRUE(res_op.ok()); + results = res_op.get(); + + ASSERT_EQ(4, results["hits"].size()); + ASSERT_EQ(4, results["found"].get()); + coll_mul_fields->remove_override("include-rule"); }