mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 05:08:43 +08:00
Handle colon character in pinned hit IDs.
This commit is contained in:
parent
254c58dd31
commit
3a86d58506
@ -186,6 +186,9 @@ private:
|
||||
|
||||
Option<bool> parse_filter_query(const std::string& simple_filter_query, std::vector<filter>& filters);
|
||||
|
||||
Option<bool> parse_pinned_hits(const std::string& pinned_hits_str,
|
||||
std::map<size_t, std::vector<std::string>>& pinned_hits);
|
||||
|
||||
public:
|
||||
Collection() = delete;
|
||||
|
||||
@ -254,8 +257,8 @@ public:
|
||||
const size_t highlight_affix_num_tokens = 4,
|
||||
const std::string & highlight_full_fields = "",
|
||||
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
|
||||
const std::map<size_t, std::vector<std::string>>& pinned_hits={},
|
||||
const std::vector<std::string>& hidden_hits={},
|
||||
const std::string& pinned_hits_str="",
|
||||
const std::string& hidden_hits="",
|
||||
const std::vector<std::string>& group_by_fields={},
|
||||
const size_t group_limit = 0,
|
||||
const std::string& highlight_start_tag="<mark>",
|
||||
|
@ -505,8 +505,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
const size_t highlight_affix_num_tokens,
|
||||
const std::string & highlight_full_fields,
|
||||
size_t typo_tokens_threshold,
|
||||
const std::map<size_t, std::vector<std::string>>& pinned_hits,
|
||||
const std::vector<std::string>& hidden_hits,
|
||||
const std::string& pinned_hits_str,
|
||||
const std::string& hidden_hits_str,
|
||||
const std::vector<std::string>& group_by_fields,
|
||||
const size_t group_limit,
|
||||
const std::string& highlight_start_tag,
|
||||
@ -523,6 +523,17 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
|
||||
std::vector<uint32_t> excluded_ids;
|
||||
std::map<size_t, std::vector<uint32_t>> include_ids; // position => list of IDs
|
||||
std::map<size_t, std::vector<std::string>> pinned_hits;
|
||||
|
||||
Option<bool> pinned_hits_op = parse_pinned_hits(pinned_hits_str, pinned_hits);
|
||||
|
||||
if(!pinned_hits_op.ok()) {
|
||||
return Option<nlohmann::json>(400, pinned_hits_op.error());
|
||||
}
|
||||
|
||||
std::vector<std::string> hidden_hits;
|
||||
StringUtils::split(hidden_hits_str, hidden_hits, ",");
|
||||
|
||||
populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids);
|
||||
|
||||
/*for(auto kv: include_ids) {
|
||||
@ -1810,3 +1821,41 @@ Option<bool> Collection::parse_filter_query(const std::string& simple_filter_que
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Collection::parse_pinned_hits(const std::string& pinned_hits_str,
|
||||
std::map<size_t, std::vector<std::string>>& pinned_hits) {
|
||||
if(!pinned_hits_str.empty()) {
|
||||
std::vector<std::string> pinned_hits_strs;
|
||||
StringUtils::split(pinned_hits_str, pinned_hits_strs, ",");
|
||||
|
||||
for(const std::string & pinned_hits_part: pinned_hits_strs) {
|
||||
std::vector<std::string> expression_parts;
|
||||
size_t index = pinned_hits_part.size() - 1;
|
||||
while(index >= 0 && pinned_hits_part[index] != ':') {
|
||||
index--;
|
||||
}
|
||||
|
||||
if(index == 0) {
|
||||
return Option<bool>(false, "Pinned hits are not in expected format.");
|
||||
}
|
||||
|
||||
std::string pinned_id = pinned_hits_part.substr(0, index);
|
||||
std::string pinned_pos = pinned_hits_part.substr(index+1);
|
||||
|
||||
if(!StringUtils::is_positive_integer(pinned_pos)) {
|
||||
return Option<bool>(false, "Pinned hits are not in expected format.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int position = std::stoi(pinned_pos);
|
||||
if(position == 0) {
|
||||
return Option<bool>(false, "Pinned hits must start from position 1.");
|
||||
return false;
|
||||
}
|
||||
|
||||
pinned_hits[position].emplace_back(pinned_id);
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
@ -447,39 +447,12 @@ bool get_search(http_req & req, http_res & res) {
|
||||
}
|
||||
}
|
||||
|
||||
std::map<size_t, std::vector<std::string>> pinned_hits;
|
||||
|
||||
if(req.params.count(PINNED_HITS) != 0) {
|
||||
std::vector<std::string> pinned_hits_strs;
|
||||
StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ",");
|
||||
|
||||
for(const std::string & pinned_hits_str: pinned_hits_strs) {
|
||||
std::vector<std::string> expression_parts;
|
||||
StringUtils::split(pinned_hits_str, expression_parts, ":");
|
||||
|
||||
if(expression_parts.size() != 2) {
|
||||
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!StringUtils::is_positive_integer(expression_parts[1])) {
|
||||
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int position = std::stoi(expression_parts[1]);
|
||||
if(position == 0) {
|
||||
res.set_400(std::string("Parameter `") + PINNED_HITS + "` is malformed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
pinned_hits[position].emplace_back(expression_parts[0]);
|
||||
}
|
||||
if(req.params.count(PINNED_HITS) == 0) {
|
||||
req.params[PINNED_HITS] = "";
|
||||
}
|
||||
|
||||
std::vector<std::string> hidden_hits;
|
||||
if(req.params.count(HIDDEN_HITS) != 0) {
|
||||
StringUtils::split(req.params[HIDDEN_HITS], hidden_hits, ",");
|
||||
if(req.params.count(HIDDEN_HITS) == 0) {
|
||||
req.params[HIDDEN_HITS] = "";
|
||||
}
|
||||
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
@ -513,8 +486,8 @@ bool get_search(http_req & req, http_res & res) {
|
||||
static_cast<size_t>(std::stol(req.params[HIGHLIGHT_AFFIX_NUM_TOKENS])),
|
||||
req.params[HIGHLIGHT_FULL_FIELDS],
|
||||
typo_tokens_threshold,
|
||||
pinned_hits,
|
||||
hidden_hits,
|
||||
req.params[PINNED_HITS],
|
||||
req.params[HIDDEN_HITS],
|
||||
group_by_fields,
|
||||
static_cast<size_t>(std::stol(req.params[GROUP_LIMIT])),
|
||||
req.params[HIGHLIGHT_START_TAG],
|
||||
|
@ -267,9 +267,7 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
|
||||
std::map<size_t, std::vector<std::string>> pinned_hits;
|
||||
pinned_hits[1] = {"13"};
|
||||
pinned_hits[2] = {"4"};
|
||||
auto pinned_hits = "13:1,4:2";
|
||||
|
||||
// basic pinning
|
||||
|
||||
@ -289,8 +287,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
|
||||
|
||||
// both pinning and hiding
|
||||
|
||||
std::vector<std::string> hidden_hits;
|
||||
hidden_hits = {"11", "16"};
|
||||
std::string hidden_hits="11,16";
|
||||
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 50, 1, FREQUENCY,
|
||||
false, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
@ -303,9 +300,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
|
||||
ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// paginating such that pinned hits appear on second page
|
||||
pinned_hits.clear();
|
||||
pinned_hits[4] = {"13"};
|
||||
pinned_hits[5] = {"4"};
|
||||
pinned_hits = "13:4,4:5";
|
||||
|
||||
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 2, 2, FREQUENCY,
|
||||
false, Index::DROP_TOKENS_THRESHOLD,
|
||||
@ -356,10 +351,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, PinnedHitsSmallerThanPageSize) {
|
||||
std::map<size_t, std::vector<std::string>> pinned_hits;
|
||||
pinned_hits[1] = {"17"};
|
||||
pinned_hits[4] = {"13"};
|
||||
pinned_hits[3] = {"11"};
|
||||
auto pinned_hits = "17:1,13:4,11:3";
|
||||
|
||||
// pinned hits larger than page size: check that pagination works
|
||||
|
||||
@ -400,11 +392,7 @@ TEST_F(CollectionOverrideTest, PinnedHitsSmallerThanPageSize) {
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, PinnedHitsLargerThanPageSize) {
|
||||
std::map<size_t, std::vector<std::string>> pinned_hits;
|
||||
pinned_hits[1] = {"6"};
|
||||
pinned_hits[2] = {"1"};
|
||||
pinned_hits[3] = {"16"};
|
||||
pinned_hits[4] = {"11"};
|
||||
auto pinned_hits = "6:1,1:2,16:3,11:4";
|
||||
|
||||
// pinned hits larger than page size: check that pagination works
|
||||
|
||||
@ -445,11 +433,43 @@ TEST_F(CollectionOverrideTest, PinnedHitsLargerThanPageSize) {
|
||||
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, PinnedHitsWhenThereAreNotEnoughResults) {
|
||||
auto pinned_hits = "6:1,1:2,11:5";
|
||||
|
||||
// multiple pinnned hits specified, but query produces no result
|
||||
|
||||
auto results = coll_mul_fields->search("notfoundquery", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
|
||||
false, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "starring: will", 30, 5,
|
||||
"", 10,
|
||||
pinned_hits, {}).get();
|
||||
|
||||
ASSERT_EQ(3, results["found"].get<size_t>());
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("11", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// multiple pinned hits but only single result
|
||||
results = coll_mul_fields->search("burgundy", {"title"}, "", {"starring"}, {}, 0, 10, 1, FREQUENCY,
|
||||
false, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "starring: will", 30, 5,
|
||||
"", 10,
|
||||
pinned_hits, {}).get();
|
||||
|
||||
ASSERT_EQ(4, results["found"].get<size_t>());
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("11", results["hits"][3]["document"]["id"].get<std::string>().c_str());
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, PinnedHitsGrouping) {
|
||||
std::map<size_t, std::vector<std::string>> pinned_hits;
|
||||
pinned_hits[1] = {"6", "8"};
|
||||
pinned_hits[2] = {"1"};
|
||||
pinned_hits[3] = {"13", "4"};
|
||||
auto pinned_hits = "6:1,8:1,1:2,13:3,4:3";
|
||||
|
||||
// without any grouping parameter, only the first ID in a position should be picked
|
||||
// and other IDs should appear in their original positions
|
||||
@ -499,4 +519,52 @@ TEST_F(CollectionOverrideTest, PinnedHitsGrouping) {
|
||||
|
||||
ASSERT_STREQ("11", results["grouped_hits"][3]["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("16", results["grouped_hits"][4]["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, PinnedHitsIdsHavingColon) {
|
||||
Collection *coll1;
|
||||
|
||||
std::vector<field> fields = {field("url", field_types::STRING, true),
|
||||
field("points", field_types::INT32, false)};
|
||||
|
||||
std::vector<sort_by> sort_fields = { sort_by("points", "DESC") };
|
||||
|
||||
coll1 = collectionManager.get_collection("coll1");
|
||||
if(coll1 == nullptr) {
|
||||
coll1 = collectionManager.create_collection("coll1", 4, fields, "points").get();
|
||||
}
|
||||
|
||||
for(size_t i=1; i<=10; i++) {
|
||||
nlohmann::json doc;
|
||||
doc["id"] = std::string("https://example.com/") + std::to_string(i);
|
||||
doc["url"] = std::string("https://example.com/") + std::to_string(i);
|
||||
doc["points"] = i;
|
||||
|
||||
coll1->add(doc.dump());
|
||||
}
|
||||
|
||||
std::vector<std::string> query_fields = {"url"};
|
||||
std::vector<std::string> facets;
|
||||
|
||||
std::string pinned_hits_str = "https://example.com/1:1, https://example.com/3:2"; // can have space
|
||||
|
||||
auto res_op = coll1->search("*", {"url"}, "", {}, {}, 0, 25, 1, FREQUENCY,
|
||||
false, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10,
|
||||
pinned_hits_str, {});
|
||||
|
||||
ASSERT_TRUE(res_op.ok());
|
||||
|
||||
auto res = res_op.get();
|
||||
|
||||
ASSERT_EQ(10, res["found"].get<size_t>());
|
||||
ASSERT_STREQ("https://example.com/1", res["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("https://example.com/3", res["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("https://example.com/10", res["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("https://example.com/9", res["hits"][3]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("https://example.com/2", res["hits"][9]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user