mirror of
https://github.com/typesense/typesense.git
synced 2025-05-17 12:12:35 +08:00
Support for pinning and hiding hits during query time.
This commit is contained in:
parent
6af35f5de8
commit
f1d0f279c7
@ -154,7 +154,10 @@ private:
|
||||
|
||||
void remove_document(nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
|
||||
|
||||
void populate_overrides(std::string query, std::map<uint32_t, size_t> & id_pos_map,
|
||||
void populate_overrides(std::string query,
|
||||
const std::map<std::string, size_t>& pinned_hits,
|
||||
const std::vector<std::string>& hidden_hits,
|
||||
std::map<uint32_t, size_t> & id_pos_map,
|
||||
std::vector<uint32_t> & included_ids, std::vector<uint32_t> & excluded_ids);
|
||||
|
||||
static bool facet_count_compare(const std::pair<uint64_t, facet_count_t>& a,
|
||||
@ -230,7 +233,9 @@ public:
|
||||
const std::string & simple_facet_query = "",
|
||||
const size_t snippet_threshold = 30,
|
||||
const std::string & highlight_full_fields = "",
|
||||
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD);
|
||||
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
|
||||
const std::map<std::string, size_t>& pinned_hits={},
|
||||
const std::vector<std::string>& hidden_hits={});
|
||||
|
||||
Option<nlohmann::json> get(const std::string & id);
|
||||
|
||||
|
@ -140,6 +140,10 @@ struct StringUtils {
|
||||
return (*p == 0);
|
||||
}
|
||||
|
||||
// Returns true when `s` is a non-empty run of ASCII decimal digits that denotes a value
// greater than zero. Leading zeros are tolerated (e.g. "007"). Empty strings, signs,
// whitespace, and all-zero strings such as "0" or "000" are rejected.
//
// NOTE: the number of digits is not bounded, so a `true` result does not guarantee that the
// value fits into an int/long -- callers converting with std::stoi and friends must still
// guard against overflow.
static bool is_positive_integer(const std::string& s) {
    if(s.empty() || s.find_first_not_of("0123456789") != std::string::npos) {
        return false;
    }

    // zero is not a positive integer: require at least one digit other than '0'
    return s.find_first_not_of('0') != std::string::npos;
}
|
||||
|
||||
// Adapted from: http://stackoverflow.com/a/2845275/131050
|
||||
static bool is_uint64_t(const std::string &s) {
|
||||
if(s.empty()) {
|
||||
|
@ -281,10 +281,39 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash
|
||||
}
|
||||
}
|
||||
|
||||
void Collection::populate_overrides(std::string query, std::map<uint32_t, size_t> & id_pos_map,
|
||||
std::vector<uint32_t> & included_ids, std::vector<uint32_t> & excluded_ids) {
|
||||
void Collection::populate_overrides(std::string query,
|
||||
const std::map<std::string, size_t>& pinned_hits,
|
||||
const std::vector<std::string>& hidden_hits,
|
||||
std::map<uint32_t, size_t> & id_pos_map,
|
||||
std::vector<uint32_t> & included_ids,
|
||||
std::vector<uint32_t> & excluded_ids) {
|
||||
StringUtils::tolowercase(query);
|
||||
|
||||
// NOTE: if pinned or hidden hits are provided, then overrides will be just ignored
|
||||
|
||||
if(!pinned_hits.empty()) {
|
||||
for(const auto & hit: pinned_hits) {
|
||||
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.first);
|
||||
if(seq_id_op.ok()) {
|
||||
included_ids.push_back(seq_id_op.get());
|
||||
id_pos_map[seq_id_op.get()] = hit.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(!hidden_hits.empty()) {
|
||||
for(const auto & hit: hidden_hits) {
|
||||
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit);
|
||||
if(seq_id_op.ok()) {
|
||||
excluded_ids.push_back(seq_id_op.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(!hidden_hits.empty() || !pinned_hits.empty()) {
|
||||
return ;
|
||||
}
|
||||
|
||||
for(const auto & override_kv: overrides) {
|
||||
const auto & override = override_kv.second;
|
||||
|
||||
@ -320,12 +349,14 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
const std::string & simple_facet_query,
|
||||
const size_t snippet_threshold,
|
||||
const std::string & highlight_full_fields,
|
||||
size_t typo_tokens_threshold ) {
|
||||
size_t typo_tokens_threshold,
|
||||
const std::map<std::string, size_t>& pinned_hits,
|
||||
const std::vector<std::string>& hidden_hits) {
|
||||
|
||||
std::vector<uint32_t> included_ids;
|
||||
std::vector<uint32_t> excluded_ids;
|
||||
std::map<uint32_t, size_t> id_pos_map;
|
||||
populate_overrides(query, id_pos_map, included_ids, excluded_ids);
|
||||
populate_overrides(query, pinned_hits, hidden_hits, id_pos_map, included_ids, excluded_ids);
|
||||
|
||||
std::map<uint32_t, std::vector<uint32_t>> index_to_included_ids;
|
||||
std::map<uint32_t, std::vector<uint32_t>> index_to_excluded_ids;
|
||||
|
@ -243,6 +243,9 @@ bool get_search(http_req & req, http_res & res) {
|
||||
const char *INCLUDE_FIELDS = "include_fields";
|
||||
const char *EXCLUDE_FIELDS = "exclude_fields";
|
||||
|
||||
const char *PINNED_HITS = "pinned_hits";
|
||||
const char *HIDDEN_HITS = "hidden_hits";
|
||||
|
||||
// strings under this length will be fully highlighted, instead of showing a snippet of relevant portion
|
||||
const char *SNIPPET_THRESHOLD = "snippet_threshold";
|
||||
|
||||
@ -385,6 +388,40 @@ bool get_search(http_req & req, http_res & res) {
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, size_t> pinned_hits;
|
||||
if(req.params.count(PINNED_HITS) != 0) {
|
||||
std::vector<std::string> pinned_hits_strs;
|
||||
StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ",");
|
||||
|
||||
for(const std::string & pinned_hits_str: pinned_hits_strs) {
|
||||
std::vector<std::string> expression_parts;
|
||||
StringUtils::split(pinned_hits_str, expression_parts, ":");
|
||||
|
||||
if(expression_parts.size() != 2) {
|
||||
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!StringUtils::is_positive_integer(expression_parts[1])) {
|
||||
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int position = std::stoi(expression_parts[1]);
|
||||
if(position == 0) {
|
||||
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
pinned_hits.emplace(expression_parts[0], position);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> hidden_hits;
|
||||
if(req.params.count(HIDDEN_HITS) != 0) {
|
||||
StringUtils::split(req.params[HIDDEN_HITS], hidden_hits, ",");
|
||||
}
|
||||
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
Collection* collection = collectionManager.get_collection(req.params["collection"]);
|
||||
|
||||
@ -415,7 +452,9 @@ bool get_search(http_req & req, http_res & res) {
|
||||
req.params[FACET_QUERY],
|
||||
static_cast<size_t>(std::stoi(req.params[SNIPPET_THRESHOLD])),
|
||||
req.params[HIGHLIGHT_FULL_FIELDS],
|
||||
typo_tokens_threshold
|
||||
typo_tokens_threshold,
|
||||
pinned_hits,
|
||||
hidden_hits
|
||||
);
|
||||
|
||||
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
|
@ -251,4 +251,74 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
|
||||
coll_mul_fields->remove_override("include-rule");
|
||||
}
|
||||
|
||||
// Exercises query-time curation through the `pinned_hits` and `hidden_hits` search arguments:
// pinned documents must appear at their requested (1-based) positions, hidden documents must be
// dropped from the results, and stored override rules must be skipped whenever either of the
// two arguments is supplied.
TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
    std::map<std::string, size_t> pinned_hits;
    std::vector<std::string> hidden_hits;
    pinned_hits["13"] = 1;
    pinned_hits["4"] = 2;

    // basic pinning: documents "13" and "4" are forced into the first two slots

    auto results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
                                           false, Index::DROP_TOKENS_THRESHOLD,
                                           spp::sparse_hash_set<std::string>(),
                                           spp::sparse_hash_set<std::string>(), 10, 500, "starring: will", 30,
                                           "", 10,
                                           pinned_hits, {}).get();

    ASSERT_EQ(10, results["found"].get<size_t>());
    ASSERT_STREQ("13", results["hits"][0]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("11", results["hits"][2]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("16", results["hits"][3]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("6", results["hits"][4]["document"]["id"].get<std::string>().c_str());

    // both pinning and hiding: "11" and "16" disappear, so "6" moves up right behind the pinned hits

    hidden_hits = {"11", "16"};
    results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
                                      false, Index::DROP_TOKENS_THRESHOLD,
                                      spp::sparse_hash_set<std::string>(),
                                      spp::sparse_hash_set<std::string>(), 10, 500, "starring: will", 30,
                                      "", 10,
                                      pinned_hits, hidden_hits).get();

    ASSERT_STREQ("13", results["hits"][0]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get<std::string>().c_str());

    // query-time hits take precedence over stored override rules

    nlohmann::json override_json_include = {
        {"id", "include-rule"},
        {
            "rule", {
                {"query", "the"},
                {"match", override_t::MATCH_EXACT}
            }
        }
    };

    // The rule below tries to include an ID that is also hidden via `hidden_hits`. The include
    // must have no effect: when pinned or hidden hits are provided, overrides are skipped entirely.
    override_json_include["includes"] = nlohmann::json::array();
    override_json_include["includes"][0] = nlohmann::json::object();
    override_json_include["includes"][0]["id"] = "11";
    override_json_include["includes"][0]["position"] = 1;

    override_t override_include(override_json_include);
    coll_mul_fields->add_override(override_include);

    results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
                                      false, Index::DROP_TOKENS_THRESHOLD,
                                      spp::sparse_hash_set<std::string>(),
                                      spp::sparse_hash_set<std::string>(), 10, 500, "starring: will", 30,
                                      "", 10,
                                      {}, hidden_hits).get();

    ASSERT_EQ(8, results["found"].get<size_t>());
    ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get<std::string>().c_str());
    ASSERT_STREQ("8", results["hits"][1]["document"]["id"].get<std::string>().c_str());

    // remove the rule registered above so it does not linger on the shared collection
    coll_mul_fields->remove_override("include-rule");
}
|
Loading…
x
Reference in New Issue
Block a user