diff --git a/include/collection.h b/include/collection.h index bcfa8150..7a77699b 100644 --- a/include/collection.h +++ b/include/collection.h @@ -186,6 +186,9 @@ private: Option parse_filter_query(const std::string& simple_filter_query, std::vector& filters); + Option parse_pinned_hits(const std::string& pinned_hits_str, + std::map>& pinned_hits); + public: Collection() = delete; @@ -254,8 +257,8 @@ public: const size_t highlight_affix_num_tokens = 4, const std::string & highlight_full_fields = "", size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD, - const std::map>& pinned_hits={}, - const std::vector& hidden_hits={}, + const std::string& pinned_hits_str="", + const std::string& hidden_hits="", const std::vector& group_by_fields={}, const size_t group_limit = 0, const std::string& highlight_start_tag="", diff --git a/src/collection.cpp b/src/collection.cpp index 06a35bbe..49ef4129 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -505,8 +505,8 @@ Option Collection::search(const std::string & query, const std:: const size_t highlight_affix_num_tokens, const std::string & highlight_full_fields, size_t typo_tokens_threshold, - const std::map>& pinned_hits, - const std::vector& hidden_hits, + const std::string& pinned_hits_str, + const std::string& hidden_hits_str, const std::vector& group_by_fields, const size_t group_limit, const std::string& highlight_start_tag, @@ -523,6 +523,17 @@ Option Collection::search(const std::string & query, const std:: std::vector excluded_ids; std::map> include_ids; // position => list of IDs + std::map> pinned_hits; + + Option pinned_hits_op = parse_pinned_hits(pinned_hits_str, pinned_hits); + + if(!pinned_hits_op.ok()) { + return Option(400, pinned_hits_op.error()); + } + + std::vector hidden_hits; + StringUtils::split(hidden_hits_str, hidden_hits, ","); + populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids); /*for(auto kv: include_ids) { @@ -1810,3 +1821,41 @@ Option Collection::parse_filter_query(const std::string& simple_filter_que return Option(true); } + +Option Collection::parse_pinned_hits(const std::string& pinned_hits_str, + std::map>& pinned_hits) { + if(!pinned_hits_str.empty()) { + std::vector pinned_hits_strs; + StringUtils::split(pinned_hits_str, pinned_hits_strs, ","); + + for(const std::string & pinned_hits_part: pinned_hits_strs) { + std::vector expression_parts; + size_t index = pinned_hits_part.size() - 1; + while(index >= 0 && pinned_hits_part[index] != ':') { + index--; + } + + if(index == 0) { + return Option(false, "Pinned hits are not in expected format."); + } + + std::string pinned_id = pinned_hits_part.substr(0, index); + std::string pinned_pos = pinned_hits_part.substr(index+1); + + if(!StringUtils::is_positive_integer(pinned_pos)) { + return Option(false, "Pinned hits are not in expected format."); + return false; + } + + int position = std::stoi(pinned_pos); + if(position == 0) { + return Option(false, "Pinned hits must start from position 1."); + return false; + } + + pinned_hits[position].emplace_back(pinned_id); + } + } + + return Option(true); +} diff --git a/src/core_api.cpp b/src/core_api.cpp index 17f1fcc1..76644b85 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -447,39 +447,12 @@ bool get_search(http_req & req, http_res & res) { } } - std::map> pinned_hits; - - if(req.params.count(PINNED_HITS) != 0) { - std::vector pinned_hits_strs; - StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ","); - - for(const std::string & pinned_hits_str: pinned_hits_strs) { - std::vector expression_parts; - StringUtils::split(pinned_hits_str, expression_parts, ":"); - - if(expression_parts.size() != 2) { - res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed."); - return false; - } - - if(!StringUtils::is_positive_integer(expression_parts[1])) { - res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed."); - return false; - } - - int position = std::stoi(expression_parts[1]); - if(position == 0) { - res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed."); - return false; - } - - pinned_hits[position].emplace_back(expression_parts[0]); - } + if(req.params.count(PINNED_HITS) == 0) { + req.params[PINNED_HITS] = ""; } - std::vector hidden_hits; - if(req.params.count(HIDDEN_HITS) != 0) { - StringUtils::split(req.params[HIDDEN_HITS], hidden_hits, ","); + if(req.params.count(HIDDEN_HITS) == 0) { + req.params[HIDDEN_HITS] = ""; } CollectionManager & collectionManager = CollectionManager::get_instance(); @@ -513,8 +486,8 @@ bool get_search(http_req & req, http_res & res) { static_cast(std::stol(req.params[HIGHLIGHT_AFFIX_NUM_TOKENS])), req.params[HIGHLIGHT_FULL_FIELDS], typo_tokens_threshold, - pinned_hits, - hidden_hits, + req.params[PINNED_HITS], + req.params[HIDDEN_HITS], group_by_fields, static_cast(std::stol(req.params[GROUP_LIMIT])), req.params[HIGHLIGHT_START_TAG], diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 865886d3..7c08be29 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -267,9 +267,7 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) { } TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { - std::map> pinned_hits; - pinned_hits[1] = {"13"}; - pinned_hits[2] = {"4"}; + auto pinned_hits = "13:1,4:2"; // basic pinning @@ -289,8 +287,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { // both pinning and hiding - std::vector hidden_hits; - hidden_hits = {"11", "16"}; + std::string hidden_hits="11,16"; results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 50, 1, FREQUENCY, false, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), @@ -303,9 +300,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get().c_str()); // paginating such that pinned hits appear on second page - pinned_hits.clear(); - pinned_hits[4] = {"13"}; - pinned_hits[5] = {"4"}; + pinned_hits = "13:4,4:5"; results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 2, 2, FREQUENCY, false, Index::DROP_TOKENS_THRESHOLD, @@ -356,10 +351,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { } TEST_F(CollectionOverrideTest, PinnedHitsSmallerThanPageSize) { - std::map> pinned_hits; - pinned_hits[1] = {"17"}; - pinned_hits[4] = {"13"}; - pinned_hits[3] = {"11"}; + auto pinned_hits = "17:1,13:4,11:3"; // pinned hits larger than page size: check that pagination works @@ -400,11 +392,7 @@ TEST_F(CollectionOverrideTest, PinnedHitsSmallerThanPageSize) { } TEST_F(CollectionOverrideTest, PinnedHitsLargerThanPageSize) { - std::map> pinned_hits; - pinned_hits[1] = {"6"}; - pinned_hits[2] = {"1"}; - pinned_hits[3] = {"16"}; - pinned_hits[4] = {"11"}; + auto pinned_hits = "6:1,1:2,16:3,11:4"; // pinned hits larger than page size: check that pagination works @@ -445,11 +433,43 @@ TEST_F(CollectionOverrideTest, PinnedHitsLargerThanPageSize) { ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); } +TEST_F(CollectionOverrideTest, PinnedHitsWhenThereAreNotEnoughResults) { + auto pinned_hits = "6:1,1:2,11:5"; + + // multiple pinnned hits specified, but query produces no result + + auto results = coll_mul_fields->search("notfoundquery", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "starring: will", 30, 5, + "", 10, + pinned_hits, {}).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("11", results["hits"][2]["document"]["id"].get().c_str()); + + // multiple pinned hits but only single result + results = coll_mul_fields->search("burgundy", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "starring: will", 30, 5, + "", 10, + pinned_hits, {}).get(); + + ASSERT_EQ(4, results["found"].get()); + ASSERT_EQ(4, results["hits"].size()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_STREQ("11", results["hits"][3]["document"]["id"].get().c_str()); +} + TEST_F(CollectionOverrideTest, PinnedHitsGrouping) { - std::map> pinned_hits; - pinned_hits[1] = {"6", "8"}; - pinned_hits[2] = {"1"}; - pinned_hits[3] = {"13", "4"}; + auto pinned_hits = "6:1,8:1,1:2,13:3,4:3"; // without any grouping parameter, only the first ID in a position should be picked // and other IDs should appear in their original positions @@ -499,4 +519,52 @@ TEST_F(CollectionOverrideTest, PinnedHitsGrouping) { ASSERT_STREQ("11", results["grouped_hits"][3]["hits"][0]["document"]["id"].get().c_str()); ASSERT_STREQ("16", results["grouped_hits"][4]["hits"][0]["document"]["id"].get().c_str()); +} + +TEST_F(CollectionOverrideTest, PinnedHitsIdsHavingColon) { + Collection *coll1; + + std::vector fields = {field("url", field_types::STRING, true), + field("points", field_types::INT32, false)}; + + std::vector sort_fields = { sort_by("points", "DESC") }; + + coll1 = collectionManager.get_collection("coll1"); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 4, fields, "points").get(); + } + + for(size_t i=1; i<=10; i++) { + nlohmann::json doc; + doc["id"] = std::string("https://example.com/") + std::to_string(i); + doc["url"] = std::string("https://example.com/") + std::to_string(i); + doc["points"] = i; + + coll1->add(doc.dump()); + } + + std::vector query_fields = {"url"}; + std::vector facets; + + std::string pinned_hits_str = "https://example.com/1:1, https://example.com/3:2"; // can have space + + auto res_op = coll1->search("*", {"url"}, "", {}, {}, 0, 25, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + pinned_hits_str, {}); + + ASSERT_TRUE(res_op.ok()); + + auto res = res_op.get(); + + ASSERT_EQ(10, res["found"].get()); + ASSERT_STREQ("https://example.com/1", res["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("https://example.com/3", res["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("https://example.com/10", res["hits"][2]["document"]["id"].get().c_str()); + ASSERT_STREQ("https://example.com/9", res["hits"][3]["document"]["id"].get().c_str()); + ASSERT_STREQ("https://example.com/2", res["hits"][9]["document"]["id"].get().c_str()); + + collectionManager.drop_collection("coll1"); } \ No newline at end of file