From f1d0f279c79a7fb8f7d35788034f994fad317b95 Mon Sep 17 00:00:00 2001 From: kishorenc Date: Thu, 21 May 2020 19:40:27 +0530 Subject: [PATCH] Support for pinning and hiding hits during query time. --- include/collection.h | 9 +++- include/string_utils.h | 4 ++ src/collection.cpp | 39 +++++++++++++++-- src/core_api.cpp | 41 +++++++++++++++++- test/collection_override_test.cpp | 70 +++++++++++++++++++++++++++++++ 5 files changed, 156 insertions(+), 7 deletions(-) diff --git a/include/collection.h b/include/collection.h index 71afb27d..36c173ad 100644 --- a/include/collection.h +++ b/include/collection.h @@ -154,7 +154,10 @@ private: void remove_document(nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); - void populate_overrides(std::string query, std::map & id_pos_map, + void populate_overrides(std::string query, + const std::map& pinned_hits, + const std::vector& hidden_hits, + std::map & id_pos_map, std::vector & included_ids, std::vector & excluded_ids); static bool facet_count_compare(const std::pair& a, @@ -230,7 +233,9 @@ public: const std::string & simple_facet_query = "", const size_t snippet_threshold = 30, const std::string & highlight_full_fields = "", - size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD); + size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD, + const std::map& pinned_hits={}, + const std::vector& hidden_hits={}); Option get(const std::string & id); diff --git a/include/string_utils.h b/include/string_utils.h index 844ee383..5a19b3ca 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -140,6 +140,10 @@ struct StringUtils { return (*p == 0); } + static bool is_positive_integer(const std::string& s) { + return !s.empty() && s.find_first_not_of("0123456789") == std::string::npos; + } + // Adapted from: http://stackoverflow.com/a/2845275/131050 static bool is_uint64_t(const std::string &s) { if(s.empty()) { diff --git a/src/collection.cpp b/src/collection.cpp index bd077e05..ce8ab1a7 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -281,10 +281,39 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash } } -void Collection::populate_overrides(std::string query, std::map & id_pos_map, - std::vector & included_ids, std::vector & excluded_ids) { +void Collection::populate_overrides(std::string query, + const std::map& pinned_hits, + const std::vector& hidden_hits, + std::map & id_pos_map, + std::vector & included_ids, + std::vector & excluded_ids) { StringUtils::tolowercase(query); + // NOTE: if pinned or hidden hits are provided, then overrides will be just ignored + + if(!pinned_hits.empty()) { + for(const auto & hit: pinned_hits) { + Option seq_id_op = doc_id_to_seq_id(hit.first); + if(seq_id_op.ok()) { + included_ids.push_back(seq_id_op.get()); + id_pos_map[seq_id_op.get()] = hit.second; + } + } + } + + if(!hidden_hits.empty()) { + for(const auto & hit: hidden_hits) { + Option seq_id_op = doc_id_to_seq_id(hit); + if(seq_id_op.ok()) { + excluded_ids.push_back(seq_id_op.get()); + } + } + } + + if(!hidden_hits.empty() || !pinned_hits.empty()) { + return ; + } + for(const auto & override_kv: overrides) { const auto & override = override_kv.second; @@ -320,12 +349,14 @@ Option Collection::search(const std::string & query, const std:: const std::string & simple_facet_query, const size_t snippet_threshold, const std::string & highlight_full_fields, - size_t typo_tokens_threshold ) { + size_t typo_tokens_threshold, + const std::map& pinned_hits, + const std::vector& hidden_hits) { std::vector included_ids; std::vector excluded_ids; std::map id_pos_map; - populate_overrides(query, id_pos_map, included_ids, excluded_ids); + populate_overrides(query, pinned_hits, hidden_hits, id_pos_map, included_ids, excluded_ids); std::map> index_to_included_ids; std::map> index_to_excluded_ids; diff --git a/src/core_api.cpp b/src/core_api.cpp index 489b9f4e..3a7c939f 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -243,6 +243,9 @@ bool get_search(http_req & req, http_res & res) { const char *INCLUDE_FIELDS = "include_fields"; const char *EXCLUDE_FIELDS = "exclude_fields"; + const char *PINNED_HITS = "pinned_hits"; + const char *HIDDEN_HITS = "hidden_hits"; + // strings under this length will be fully highlighted, instead of showing a snippet of relevant portion const char *SNIPPET_THRESHOLD = "snippet_threshold"; @@ -385,6 +388,40 @@ bool get_search(http_req & req, http_res & res) { } } + std::map pinned_hits; + if(req.params.count(PINNED_HITS) != 0) { + std::vector pinned_hits_strs; + StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ","); + + for(const std::string & pinned_hits_str: pinned_hits_strs) { + std::vector expression_parts; + StringUtils::split(pinned_hits_str, expression_parts, ":"); + + if(expression_parts.size() != 2) { + res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed."); + return false; + } + + if(!StringUtils::is_positive_integer(expression_parts[1])) { + res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed."); + return false; + } + + int position = std::stoi(expression_parts[1]); + if(position == 0) { + res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed."); + return false; + } + + pinned_hits.emplace(expression_parts[0], position); + } + } + + std::vector hidden_hits; + if(req.params.count(HIDDEN_HITS) != 0) { + StringUtils::split(req.params[HIDDEN_HITS], hidden_hits, ","); + } + CollectionManager & collectionManager = CollectionManager::get_instance(); Collection* collection = collectionManager.get_collection(req.params["collection"]); @@ -415,7 +452,9 @@ bool get_search(http_req & req, http_res & res) { req.params[FACET_QUERY], static_cast(std::stoi(req.params[SNIPPET_THRESHOLD])), req.params[HIGHLIGHT_FULL_FIELDS], - typo_tokens_threshold + typo_tokens_threshold, + pinned_hits, + hidden_hits ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index a2539da6..d23c886e 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -251,4 +251,74 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) { ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); coll_mul_fields->remove_override("include-rule"); +} + +TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { + std::map pinned_hits; + std::vector hidden_hits; + pinned_hits["13"] = 1; + pinned_hits["4"] = 2; + + // basic pinning + + auto results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, 500, "starring: will", 30, + "", 10, + pinned_hits, {}).get(); + + ASSERT_EQ(10, results["found"].get()); + ASSERT_STREQ("13", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("11", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_STREQ("16", results["hits"][3]["document"]["id"].get().c_str()); + ASSERT_STREQ("6", results["hits"][4]["document"]["id"].get().c_str()); + + // both pinning and hiding + + hidden_hits = {"11", "16"}; + results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, 500, "starring: will", 30, + "", 10, + pinned_hits, hidden_hits).get(); + + ASSERT_STREQ("13", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get().c_str()); + + // take precedence over override rules + + nlohmann::json override_json_include = { + {"id", "include-rule"}, + { + "rule", { + {"query", "the"}, + {"match", override_t::MATCH_EXACT} + } + } + }; + + // trying to include an ID that is also being hidden via `hidden_hits` query param will not work + // as if pinned or hidden hits are provided, overrides will be entirely skipped + override_json_include["includes"] = nlohmann::json::array(); + override_json_include["includes"][0] = nlohmann::json::object(); + override_json_include["includes"][0]["id"] = "11"; + override_json_include["includes"][0]["position"] = 1; + + override_t override_include(override_json_include); + coll_mul_fields->add_override(override_include); + + results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, 500, "starring: will", 30, + "", 10, + {}, {hidden_hits}).get(); + + ASSERT_EQ(8, results["found"].get()); + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("8", results["hits"][1]["document"]["id"].get().c_str()); } \ No newline at end of file