Support for pinning and hiding hits during query time.

This commit is contained in:
kishorenc 2020-05-21 19:40:27 +05:30
parent 6af35f5de8
commit f1d0f279c7
5 changed files with 156 additions and 7 deletions

View File

@ -154,7 +154,10 @@ private:
void remove_document(nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
void populate_overrides(std::string query, std::map<uint32_t, size_t> & id_pos_map,
void populate_overrides(std::string query,
const std::map<std::string, size_t>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<uint32_t, size_t> & id_pos_map,
std::vector<uint32_t> & included_ids, std::vector<uint32_t> & excluded_ids);
static bool facet_count_compare(const std::pair<uint64_t, facet_count_t>& a,
@ -230,7 +233,9 @@ public:
const std::string & simple_facet_query = "",
const size_t snippet_threshold = 30,
const std::string & highlight_full_fields = "",
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD);
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
const std::map<std::string, size_t>& pinned_hits={},
const std::vector<std::string>& hidden_hits={});
Option<nlohmann::json> get(const std::string & id);

View File

@ -140,6 +140,10 @@ struct StringUtils {
return (*p == 0);
}
// Returns true when `s` is non-empty and consists solely of the ASCII digits '0'-'9'.
// Signs, whitespace and decimal points are rejected. Note that "0", strings with leading
// zeros and digit runs of any length are all accepted, so callers that convert the value
// must perform their own zero / range checks.
static bool is_positive_integer(const std::string& s) {
    if(s.empty()) {
        return false;
    }

    for(const char ch: s) {
        const bool is_digit = (ch >= '0' && ch <= '9');
        if(!is_digit) {
            return false;
        }
    }

    return true;
}
// Adapted from: http://stackoverflow.com/a/2845275/131050
static bool is_uint64_t(const std::string &s) {
if(s.empty()) {

View File

@ -281,10 +281,39 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash
}
}
void Collection::populate_overrides(std::string query, std::map<uint32_t, size_t> & id_pos_map,
std::vector<uint32_t> & included_ids, std::vector<uint32_t> & excluded_ids) {
void Collection::populate_overrides(std::string query,
const std::map<std::string, size_t>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<uint32_t, size_t> & id_pos_map,
std::vector<uint32_t> & included_ids,
std::vector<uint32_t> & excluded_ids) {
StringUtils::tolowercase(query);
// NOTE: if pinned or hidden hits are provided, then overrides will be just ignored
if(!pinned_hits.empty()) {
for(const auto & hit: pinned_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.first);
if(seq_id_op.ok()) {
included_ids.push_back(seq_id_op.get());
id_pos_map[seq_id_op.get()] = hit.second;
}
}
}
if(!hidden_hits.empty()) {
for(const auto & hit: hidden_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit);
if(seq_id_op.ok()) {
excluded_ids.push_back(seq_id_op.get());
}
}
}
if(!hidden_hits.empty() || !pinned_hits.empty()) {
return ;
}
for(const auto & override_kv: overrides) {
const auto & override = override_kv.second;
@ -320,12 +349,14 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
const std::string & simple_facet_query,
const size_t snippet_threshold,
const std::string & highlight_full_fields,
size_t typo_tokens_threshold ) {
size_t typo_tokens_threshold,
const std::map<std::string, size_t>& pinned_hits,
const std::vector<std::string>& hidden_hits) {
std::vector<uint32_t> included_ids;
std::vector<uint32_t> excluded_ids;
std::map<uint32_t, size_t> id_pos_map;
populate_overrides(query, id_pos_map, included_ids, excluded_ids);
populate_overrides(query, pinned_hits, hidden_hits, id_pos_map, included_ids, excluded_ids);
std::map<uint32_t, std::vector<uint32_t>> index_to_included_ids;
std::map<uint32_t, std::vector<uint32_t>> index_to_excluded_ids;

View File

@ -243,6 +243,9 @@ bool get_search(http_req & req, http_res & res) {
const char *INCLUDE_FIELDS = "include_fields";
const char *EXCLUDE_FIELDS = "exclude_fields";
const char *PINNED_HITS = "pinned_hits";
const char *HIDDEN_HITS = "hidden_hits";
// strings under this length will be fully highlighted, instead of showing a snippet of relevant portion
const char *SNIPPET_THRESHOLD = "snippet_threshold";
@ -385,6 +388,40 @@ bool get_search(http_req & req, http_res & res) {
}
}
std::map<std::string, size_t> pinned_hits;
if(req.params.count(PINNED_HITS) != 0) {
std::vector<std::string> pinned_hits_strs;
StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ",");
for(const std::string & pinned_hits_str: pinned_hits_strs) {
std::vector<std::string> expression_parts;
StringUtils::split(pinned_hits_str, expression_parts, ":");
if(expression_parts.size() != 2) {
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
return false;
}
if(!StringUtils::is_positive_integer(expression_parts[1])) {
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
return false;
}
int position = std::stoi(expression_parts[1]);
if(position == 0) {
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
return false;
}
pinned_hits.emplace(expression_parts[0], position);
}
}
std::vector<std::string> hidden_hits;
if(req.params.count(HIDDEN_HITS) != 0) {
StringUtils::split(req.params[HIDDEN_HITS], hidden_hits, ",");
}
CollectionManager & collectionManager = CollectionManager::get_instance();
Collection* collection = collectionManager.get_collection(req.params["collection"]);
@ -415,7 +452,9 @@ bool get_search(http_req & req, http_res & res) {
req.params[FACET_QUERY],
static_cast<size_t>(std::stoi(req.params[SNIPPET_THRESHOLD])),
req.params[HIGHLIGHT_FULL_FIELDS],
typo_tokens_threshold
typo_tokens_threshold,
pinned_hits,
hidden_hits
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -251,4 +251,74 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
coll_mul_fields->remove_override("include-rule");
}
// Exercises the query-time `pinned_hits` / `hidden_hits` arguments of Collection::search():
// pinned documents must show up at the requested positions, hidden documents must be dropped,
// and a stored override rule must not be applied when hidden hits are passed with the query.
TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
    // All searches in this test are identical apart from the pinned / hidden hits supplied.
    const auto search_the = [&](const std::map<std::string, size_t>& pinned,
                                const std::vector<std::string>& hidden) {
        return coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
                                       false, Index::DROP_TOKENS_THRESHOLD,
                                       spp::sparse_hash_set<std::string>(),
                                       spp::sparse_hash_set<std::string>(), 10, 500, "starring: will", 30,
                                       "", 10,
                                       pinned, hidden).get();
    };

    // Document ID of the hit found at `pos` in a search response.
    const auto hit_id = [](nlohmann::json& response, const size_t pos) {
        return response["hits"][pos]["document"]["id"].get<std::string>();
    };

    std::map<std::string, size_t> pinned_hits = {{"13", 1}, {"4", 2}};
    std::vector<std::string> hidden_hits;

    // basic pinning: the two pinned documents lead, followed by the regular hits
    auto results = search_the(pinned_hits, {});

    ASSERT_EQ(10, results["found"].get<size_t>());
    ASSERT_STREQ("13", hit_id(results, 0).c_str());
    ASSERT_STREQ("4", hit_id(results, 1).c_str());
    ASSERT_STREQ("11", hit_id(results, 2).c_str());
    ASSERT_STREQ("16", hit_id(results, 3).c_str());
    ASSERT_STREQ("6", hit_id(results, 4).c_str());

    // both pinning and hiding: the hidden documents no longer follow the pinned ones
    hidden_hits = {"11", "16"};
    results = search_the(pinned_hits, hidden_hits);

    ASSERT_STREQ("13", hit_id(results, 0).c_str());
    ASSERT_STREQ("4", hit_id(results, 1).c_str());
    ASSERT_STREQ("6", hit_id(results, 2).c_str());

    // query-time hits take precedence over stored override rules
    nlohmann::json override_json_include = {
        {"id", "include-rule"},
        {
            "rule", {
                {"query", "the"},
                {"match", override_t::MATCH_EXACT}
            }
        }
    };

    // The rule tries to include document "11", which is also hidden via `hidden_hits`.
    // Since overrides are skipped entirely whenever pinned or hidden hits are provided,
    // the include must not take effect.
    nlohmann::json include_entry = nlohmann::json::object();
    include_entry["id"] = "11";
    include_entry["position"] = 1;

    override_json_include["includes"] = nlohmann::json::array();
    override_json_include["includes"].push_back(include_entry);

    override_t override_include(override_json_include);
    coll_mul_fields->add_override(override_include);

    results = search_the({}, hidden_hits);

    ASSERT_EQ(8, results["found"].get<size_t>());
    ASSERT_STREQ("6", hit_id(results, 0).c_str());
    ASSERT_STREQ("8", hit_id(results, 1).c_str());
}