Handle colon character in pinned hit IDs.

This commit is contained in:
kishorenc 2020-11-18 07:33:01 +05:30
parent 254c58dd31
commit 3a86d58506
4 changed files with 151 additions and 58 deletions

View File

@ -186,6 +186,9 @@ private:
Option<bool> parse_filter_query(const std::string& simple_filter_query, std::vector<filter>& filters);
Option<bool> parse_pinned_hits(const std::string& pinned_hits_str,
std::map<size_t, std::vector<std::string>>& pinned_hits);
public:
Collection() = delete;
@ -254,8 +257,8 @@ public:
const size_t highlight_affix_num_tokens = 4,
const std::string & highlight_full_fields = "",
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
const std::map<size_t, std::vector<std::string>>& pinned_hits={},
const std::vector<std::string>& hidden_hits={},
const std::string& pinned_hits_str="",
const std::string& hidden_hits="",
const std::vector<std::string>& group_by_fields={},
const size_t group_limit = 0,
const std::string& highlight_start_tag="<mark>",

View File

@ -505,8 +505,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
const size_t highlight_affix_num_tokens,
const std::string & highlight_full_fields,
size_t typo_tokens_threshold,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
const std::string& pinned_hits_str,
const std::string& hidden_hits_str,
const std::vector<std::string>& group_by_fields,
const size_t group_limit,
const std::string& highlight_start_tag,
@ -523,6 +523,17 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
std::vector<uint32_t> excluded_ids;
std::map<size_t, std::vector<uint32_t>> include_ids; // position => list of IDs
std::map<size_t, std::vector<std::string>> pinned_hits;
Option<bool> pinned_hits_op = parse_pinned_hits(pinned_hits_str, pinned_hits);
if(!pinned_hits_op.ok()) {
return Option<nlohmann::json>(400, pinned_hits_op.error());
}
std::vector<std::string> hidden_hits;
StringUtils::split(hidden_hits_str, hidden_hits, ",");
populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids);
/*for(auto kv: include_ids) {
@ -1810,3 +1821,41 @@ Option<bool> Collection::parse_filter_query(const std::string& simple_filter_que
return Option<bool>(true);
}
/**
 * Parses a comma separated list of `<id>:<position>` pairs into a position => IDs map.
 *
 * The separator between ID and position is the LAST colon of each pair, so IDs that
 * themselves contain colons (e.g. "https://example.com/1:1") are handled.
 *
 * @param pinned_hits_str Raw parameter value, e.g. "13:1,4:2". An empty string is valid and
 *                        leaves `pinned_hits` untouched.
 * @param pinned_hits     Output map. Each parsed ID is appended to the list stored under its
 *                        position. Pairs parsed before a malformed pair remain in the map.
 * @return `Option<bool>(true)` on success, otherwise an error Option carrying a message that
 *         describes why the input was rejected.
 */
Option<bool> Collection::parse_pinned_hits(const std::string& pinned_hits_str,
                                           std::map<size_t, std::vector<std::string>>& pinned_hits) {
    if(pinned_hits_str.empty()) {
        return Option<bool>(true);
    }

    std::vector<std::string> pinned_hits_strs;
    StringUtils::split(pinned_hits_str, pinned_hits_strs, ",");

    for(const std::string& pinned_hits_part: pinned_hits_strs) {
        // `rfind` reports a missing colon as `npos` instead of relying on an unsigned index
        // counting down past zero, which would wrap around and read out of bounds.
        const size_t colon_pos = pinned_hits_part.rfind(':');

        // No colon at all, or nothing before the colon (empty ID).
        if(colon_pos == std::string::npos || colon_pos == 0) {
            // Same (code, message) form that the search method uses to report this failure.
            return Option<bool>(400, "Pinned hits are not in expected format.");
        }

        const std::string pinned_id = pinned_hits_part.substr(0, colon_pos);
        const std::string pinned_pos = pinned_hits_part.substr(colon_pos + 1);

        if(!StringUtils::is_positive_integer(pinned_pos)) {
            return Option<bool>(400, "Pinned hits are not in expected format.");
        }

        int position = 0;

        try {
            position = std::stoi(pinned_pos);
        } catch(...) {
            // std::stoi throws when the digits do not fit in an int; the value comes from
            // request input, so report it as malformed rather than letting it propagate.
            return Option<bool>(400, "Pinned hits are not in expected format.");
        }

        if(position == 0) {
            return Option<bool>(400, "Pinned hits must start from position 1.");
        }

        pinned_hits[position].emplace_back(pinned_id);
    }

    return Option<bool>(true);
}

View File

@ -447,39 +447,12 @@ bool get_search(http_req & req, http_res & res) {
}
}
std::map<size_t, std::vector<std::string>> pinned_hits;
if(req.params.count(PINNED_HITS) != 0) {
std::vector<std::string> pinned_hits_strs;
StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ",");
for(const std::string & pinned_hits_str: pinned_hits_strs) {
std::vector<std::string> expression_parts;
StringUtils::split(pinned_hits_str, expression_parts, ":");
if(expression_parts.size() != 2) {
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
return false;
}
if(!StringUtils::is_positive_integer(expression_parts[1])) {
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
return false;
}
int position = std::stoi(expression_parts[1]);
if(position == 0) {
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
return false;
}
pinned_hits[position].emplace_back(expression_parts[0]);
}
if(req.params.count(PINNED_HITS) == 0) {
req.params[PINNED_HITS] = "";
}
std::vector<std::string> hidden_hits;
if(req.params.count(HIDDEN_HITS) != 0) {
StringUtils::split(req.params[HIDDEN_HITS], hidden_hits, ",");
if(req.params.count(HIDDEN_HITS) == 0) {
req.params[HIDDEN_HITS] = "";
}
CollectionManager & collectionManager = CollectionManager::get_instance();
@ -513,8 +486,8 @@ bool get_search(http_req & req, http_res & res) {
static_cast<size_t>(std::stol(req.params[HIGHLIGHT_AFFIX_NUM_TOKENS])),
req.params[HIGHLIGHT_FULL_FIELDS],
typo_tokens_threshold,
pinned_hits,
hidden_hits,
req.params[PINNED_HITS],
req.params[HIDDEN_HITS],
group_by_fields,
static_cast<size_t>(std::stol(req.params[GROUP_LIMIT])),
req.params[HIGHLIGHT_START_TAG],

View File

@ -267,9 +267,7 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
}
TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
std::map<size_t, std::vector<std::string>> pinned_hits;
pinned_hits[1] = {"13"};
pinned_hits[2] = {"4"};
auto pinned_hits = "13:1,4:2";
// basic pinning
@ -289,8 +287,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
// both pinning and hiding
std::vector<std::string> hidden_hits;
hidden_hits = {"11", "16"};
std::string hidden_hits="11,16";
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 50, 1, FREQUENCY,
false, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
@ -303,9 +300,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get<std::string>().c_str());
// paginating such that pinned hits appear on second page
pinned_hits.clear();
pinned_hits[4] = {"13"};
pinned_hits[5] = {"4"};
pinned_hits = "13:4,4:5";
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 2, 2, FREQUENCY,
false, Index::DROP_TOKENS_THRESHOLD,
@ -356,10 +351,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
}
TEST_F(CollectionOverrideTest, PinnedHitsSmallerThanPageSize) {
std::map<size_t, std::vector<std::string>> pinned_hits;
pinned_hits[1] = {"17"};
pinned_hits[4] = {"13"};
pinned_hits[3] = {"11"};
auto pinned_hits = "17:1,13:4,11:3";
// pinned hits larger than page size: check that pagination works
@ -400,11 +392,7 @@ TEST_F(CollectionOverrideTest, PinnedHitsSmallerThanPageSize) {
}
TEST_F(CollectionOverrideTest, PinnedHitsLargerThanPageSize) {
std::map<size_t, std::vector<std::string>> pinned_hits;
pinned_hits[1] = {"6"};
pinned_hits[2] = {"1"};
pinned_hits[3] = {"16"};
pinned_hits[4] = {"11"};
auto pinned_hits = "6:1,1:2,16:3,11:4";
// pinned hits larger than page size: check that pagination works
@ -445,11 +433,43 @@ TEST_F(CollectionOverrideTest, PinnedHitsLargerThanPageSize) {
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
}
// Pinned hits must still be returned when the query itself matches fewer documents than the
// pinned positions require.
TEST_F(CollectionOverrideTest, PinnedHitsWhenThereAreNotEnoughResults) {
    // IDs 6, 1 and 11 pinned to positions 1, 2 and 5 respectively
    const std::string pinned_hits = "6:1,1:2,11:5";

    // Case 1: multiple pinned hits specified, but the query produces no result of its own
    auto results = coll_mul_fields->search("notfoundquery", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
                                           false, Index::DROP_TOKENS_THRESHOLD,
                                           spp::sparse_hash_set<std::string>(),
                                           spp::sparse_hash_set<std::string>(), 10, "starring: will", 30, 5,
                                           "", 10,
                                           pinned_hits, {}).get();

    ASSERT_EQ(3, results["found"].get<size_t>());
    ASSERT_EQ(3, results["hits"].size());

    const std::vector<std::string> only_pinned_ids = {"6", "1", "11"};
    for(size_t hit_index = 0; hit_index < only_pinned_ids.size(); hit_index++) {
        ASSERT_STREQ(only_pinned_ids[hit_index].c_str(),
                     results["hits"][hit_index]["document"]["id"].get<std::string>().c_str());
    }

    // Case 2: multiple pinned hits, but the query matches only a single document
    results = coll_mul_fields->search("burgundy", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
                                      false, Index::DROP_TOKENS_THRESHOLD,
                                      spp::sparse_hash_set<std::string>(),
                                      spp::sparse_hash_set<std::string>(), 10, "starring: will", 30, 5,
                                      "", 10,
                                      pinned_hits, {}).get();

    ASSERT_EQ(4, results["found"].get<size_t>());
    ASSERT_EQ(4, results["hits"].size());

    // the single organic hit ("0") is expected between the pinned ones
    const std::vector<std::string> pinned_and_organic_ids = {"6", "1", "0", "11"};
    for(size_t hit_index = 0; hit_index < pinned_and_organic_ids.size(); hit_index++) {
        ASSERT_STREQ(pinned_and_organic_ids[hit_index].c_str(),
                     results["hits"][hit_index]["document"]["id"].get<std::string>().c_str());
    }
}
TEST_F(CollectionOverrideTest, PinnedHitsGrouping) {
std::map<size_t, std::vector<std::string>> pinned_hits;
pinned_hits[1] = {"6", "8"};
pinned_hits[2] = {"1"};
pinned_hits[3] = {"13", "4"};
auto pinned_hits = "6:1,8:1,1:2,13:3,4:3";
// without any grouping parameter, only the first ID in a position should be picked
// and other IDs should appear in their original positions
@ -499,4 +519,52 @@ TEST_F(CollectionOverrideTest, PinnedHitsGrouping) {
ASSERT_STREQ("11", results["grouped_hits"][3]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("16", results["grouped_hits"][4]["hits"][0]["document"]["id"].get<std::string>().c_str());
}
// Document IDs that contain colons (here: URLs) must be usable in the pinned hits parameter,
// whose own syntax is `<id>:<position>`.
TEST_F(CollectionOverrideTest, PinnedHitsIdsHavingColon) {
    std::vector<field> fields = {field("url", field_types::STRING, true),
                                 field("points", field_types::INT32, false)};

    Collection* coll1 = collectionManager.get_collection("coll1");
    if(coll1 == nullptr) {
        coll1 = collectionManager.create_collection("coll1", 4, fields, "points").get();
    }

    // 10 documents whose ID and `url` are the same URL, with `points` running from 1 to 10
    for(size_t i=1; i<=10; i++) {
        const std::string url = std::string("https://example.com/") + std::to_string(i);

        nlohmann::json doc;
        doc["id"] = url;
        doc["url"] = url;
        doc["points"] = i;
        coll1->add(doc.dump());
    }

    // a space after the comma separator is deliberately included: it must be tolerated
    std::string pinned_hits_str = "https://example.com/1:1, https://example.com/3:2";

    auto res_op = coll1->search("*", {"url"}, "", {}, {}, 0, 25, 1, FREQUENCY,
                                false, Index::DROP_TOKENS_THRESHOLD,
                                spp::sparse_hash_set<std::string>(),
                                spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                "", 10,
                                pinned_hits_str, {});

    ASSERT_TRUE(res_op.ok());

    auto res = res_op.get();

    ASSERT_EQ(10, res["found"].get<size_t>());

    // the two pinned documents come first ...
    ASSERT_STREQ("https://example.com/1", res["hits"][0]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("https://example.com/3", res["hits"][1]["document"]["id"].get<std::string>().c_str());

    // ... followed by the remaining documents, highest `points` first
    ASSERT_STREQ("https://example.com/10", res["hits"][2]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("https://example.com/9", res["hits"][3]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("https://example.com/2", res["hits"][9]["document"]["id"].get<std::string>().c_str());

    collectionManager.drop_collection("coll1");
}