Return snippets of best matched array elements instead of just the best matched element.

However, we limit the number of snippets returned to 5 for now.
This commit is contained in:
Kishore Nallan 2018-05-22 06:57:18 +05:30
parent 76febf74d0
commit c5b7f3c7e2
3 changed files with 128 additions and 75 deletions

View File

@ -19,13 +19,14 @@
class Collection {
private:
struct highlight_t {
std::string field;
std::string snippet;
std::vector<std::string> snippets;
std::vector<size_t> indices;
uint64_t match_score;
int index;
highlight_t(): match_score(0), index(-1) {
highlight_t() {
}
@ -135,6 +136,8 @@ public:
// strings under this length will be fully highlighted, instead of showing a snippet of relevant portion
enum {SNIPPET_STR_ABOVE_LEN = 30};
enum {MAX_ARRAY_MATCHES = 5};
// Using a $ prefix so that these meta keys stay above record entries in a lexicographically ordered KV store
static constexpr const char* COLLECTION_META_PREFIX = "$CM";
static constexpr const char* COLLECTION_NEXT_SEQ_PREFIX = "$CS";

View File

@ -12,6 +12,24 @@
#include "topster.h"
#include "logger.h"
struct match_index_t {
Match match;
uint64_t match_score = 0;
size_t index;
match_index_t(Match match, uint64_t match_score, size_t index): match(match), match_score(match_score),
index(index) {
}
bool operator<(const match_index_t& a) const {
if(match_score != a.match_score) {
return match_score > a.match_score;
}
return index < a.index;
}
};
Collection::Collection(const std::string name, const uint32_t collection_id, const uint32_t next_seq_id, Store *store,
const std::vector<field> &fields, const std::string & default_sorting_field,
const size_t num_indices):
@ -568,7 +586,9 @@ Option<nlohmann::json> Collection::search(std::string query, const std::vector<s
search_field.type == field_types::STRING_ARRAY)) {
highlight_t highlight;
highlight_result(search_field, searched_queries, field_order_kv, document, string_utils, highlight);
highlights.push_back(highlight);
if(!highlight.snippets.empty()) {
highlights.push_back(highlight);
}
}
}
@ -577,9 +597,11 @@ Option<nlohmann::json> Collection::search(std::string query, const std::vector<s
for(const auto highlight: highlights) {
nlohmann::json h_json = nlohmann::json::object();
h_json["field"] = highlight.field;
h_json["snippet"] = highlight.snippet;
if(highlight.index != -1) {
h_json["index"] = highlight.index;
if(!highlight.indices.empty()) {
h_json["indices"] = highlight.indices;
h_json["snippets"] = highlight.snippets;
} else {
h_json["snippet"] = highlight.snippets[0];
}
wrapper_doc["highlights"].push_back(h_json);
@ -661,9 +683,7 @@ void Collection::highlight_result(const field &search_field,
return ;
}
Match match;
uint64_t match_score = 0;
size_t matched_array_index = 0;
std::vector<match_index_t> match_indices;
for(size_t array_index = 0; array_index < array_token_positions.size(); array_index++) {
const std::vector<std::vector<uint16_t>> & token_positions = array_token_positions[array_index];
@ -674,67 +694,72 @@ void Collection::highlight_result(const field &search_field,
const Match & this_match = Match::match(field_order_kv.key, token_positions);
uint64_t this_match_score = this_match.get_match_score(1, field_order_kv.field_id);
if(this_match_score > match_score) {
match_score = this_match_score;
match = this_match;
matched_array_index = array_index;
}
match_indices.push_back(match_index_t(this_match, this_match_score, array_index));
}
std::vector<std::string> tokens;
if(search_field.type == field_types::STRING) {
StringUtils::split(document[search_field.name], tokens, " ");
} else {
StringUtils::split(document[search_field.name][matched_array_index], tokens, " ");
}
const size_t max_array_matches = std::min((size_t)MAX_ARRAY_MATCHES, match_indices.size());
std::partial_sort(match_indices.begin(), match_indices.begin()+max_array_matches, match_indices.end());
// unpack `match.offset_diffs` into `token_indices`
std::vector<size_t> token_indices;
spp::sparse_hash_set<std::string> token_hits;
for(size_t index = 0; index < max_array_matches; index++) {
const match_index_t & match_index = match_indices[index];
const Match & match = match_index.match;
size_t num_tokens_found = (size_t) match.offset_diffs[0];
for(size_t i = 1; i <= num_tokens_found; i++) {
if(match.offset_diffs[i] != std::numeric_limits<int8_t>::max()) {
size_t token_index = (size_t)(match.start_offset + match.offset_diffs[i]);
token_indices.push_back(token_index);
std::string token = tokens[token_index];
string_utils.unicode_normalize(token);
token_hits.insert(token);
}
}
auto minmax = std::minmax_element(token_indices.begin(), token_indices.end());
// For longer strings, pick surrounding tokens within N tokens of min_index and max_index for the snippet
const size_t start_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? 0 :
std::max(0, (int)(*(minmax.first) - 5));
const size_t end_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? tokens.size() :
std::min((int)tokens.size(), (int)(*(minmax.second) + 5));
std::stringstream snippet_stream;
for(size_t snippet_index = start_index; snippet_index < end_index; snippet_index++) {
if(snippet_index != start_index) {
snippet_stream << " ";
}
std::string token = tokens[snippet_index];
string_utils.unicode_normalize(token);
if(token_hits.count(token) != 0) {
snippet_stream << "<mark>" + tokens[snippet_index] + "</mark>";
std::vector<std::string> tokens;
if(search_field.type == field_types::STRING) {
StringUtils::split(document[search_field.name], tokens, " ");
} else {
snippet_stream << tokens[snippet_index];
StringUtils::split(document[search_field.name][match_index.index], tokens, " ");
}
// unpack `match.offset_diffs` into `token_indices`
std::vector<size_t> token_indices;
spp::sparse_hash_set<std::string> token_hits;
size_t num_tokens_found = (size_t) match.offset_diffs[0];
for(size_t i = 1; i <= num_tokens_found; i++) {
if(match.offset_diffs[i] != std::numeric_limits<int8_t>::max()) {
size_t token_index = (size_t)(match.start_offset + match.offset_diffs[i]);
token_indices.push_back(token_index);
std::string token = tokens[token_index];
string_utils.unicode_normalize(token);
token_hits.insert(token);
}
}
auto minmax = std::minmax_element(token_indices.begin(), token_indices.end());
// For longer strings, pick surrounding tokens within N tokens of min_index and max_index for the snippet
const size_t start_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? 0 :
std::max(0, (int)(*(minmax.first) - 5));
const size_t end_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? tokens.size() :
std::min((int)tokens.size(), (int)(*(minmax.second) + 5));
std::stringstream snippet_stream;
for(size_t snippet_index = start_index; snippet_index < end_index; snippet_index++) {
if(snippet_index != start_index) {
snippet_stream << " ";
}
std::string token = tokens[snippet_index];
string_utils.unicode_normalize(token);
if(token_hits.count(token) != 0) {
snippet_stream << "<mark>" + tokens[snippet_index] + "</mark>";
} else {
snippet_stream << tokens[snippet_index];
}
}
highlight.snippets.push_back(snippet_stream.str());
if(search_field.type == field_types::STRING_ARRAY) {
highlight.indices.push_back(match_index.index);
}
}
highlight.field = search_field.name;
highlight.snippet = snippet_stream.str();
highlight.match_score = match_score;
if(search_field.type == field_types::STRING_ARRAY) {
highlight.index = matched_array_index;
}
highlight.match_score = match_indices[0].match_score;
for (auto it = leaf_to_indices.begin(); it != leaf_to_indices.end(); it++) {
delete [] it->second;

View File

@ -556,9 +556,17 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
ASSERT_EQ(results["hits"][0]["highlights"].size(), 1);
ASSERT_STREQ(results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str(), "tags");
ASSERT_STREQ(results["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str(),
"<mark>truth</mark> <mark>about</mark>");
ASSERT_EQ(results["hits"][0]["highlights"][0]["index"].get<size_t>(), 2);
// an array's snippets must be sorted on match score, if match score is same, priority to be given to lower indices
ASSERT_EQ(3, results["hits"][0]["highlights"][0]["snippets"].size());
ASSERT_STREQ("<mark>truth</mark> <mark>about</mark>", results["hits"][0]["highlights"][0]["snippets"][0].get<std::string>().c_str());
ASSERT_STREQ("the <mark>truth</mark>", results["hits"][0]["highlights"][0]["snippets"][1].get<std::string>().c_str());
ASSERT_STREQ("<mark>about</mark> forever", results["hits"][0]["highlights"][0]["snippets"][2].get<std::string>().c_str());
ASSERT_EQ(3, results["hits"][0]["highlights"][0]["indices"].size());
ASSERT_EQ(2, results["hits"][0]["highlights"][0]["indices"][0]);
ASSERT_EQ(0, results["hits"][0]["highlights"][0]["indices"][1]);
ASSERT_EQ(1, results["hits"][0]["highlights"][0]["indices"][2]);
results = coll_array_text->search("forever truth", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY,
false, 0).get();
@ -574,8 +582,14 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
}
ASSERT_STREQ(results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str(), "tags");
ASSERT_STREQ(results["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str(), "the <mark>truth</mark>");
ASSERT_EQ(results["hits"][0]["highlights"][0]["index"].get<size_t>(), 0);
ASSERT_EQ(3, results["hits"][0]["highlights"][0]["snippets"].size());
ASSERT_STREQ("the <mark>truth</mark>", results["hits"][0]["highlights"][0]["snippets"][0].get<std::string>().c_str());
ASSERT_STREQ("about <mark>forever</mark>", results["hits"][0]["highlights"][0]["snippets"][1].get<std::string>().c_str());
ASSERT_STREQ("<mark>truth</mark> about", results["hits"][0]["highlights"][0]["snippets"][2].get<std::string>().c_str());
ASSERT_EQ(3, results["hits"][0]["highlights"][0]["indices"].size());
ASSERT_EQ(0, results["hits"][0]["highlights"][0]["indices"][0]);
ASSERT_EQ(1, results["hits"][0]["highlights"][0]["indices"][1]);
ASSERT_EQ(2, results["hits"][0]["highlights"][0]["indices"][2]);
results = coll_array_text->search("truth", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY,
false, 0).get();
@ -615,8 +629,13 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
ASSERT_EQ(3, results["hits"][0]["highlights"][1].size());
ASSERT_STREQ(results["hits"][0]["highlights"][1]["field"].get<std::string>().c_str(), "tags");
ASSERT_STREQ(results["hits"][0]["highlights"][1]["snippet"].get<std::string>().c_str(), "the <mark>truth</mark>");
ASSERT_EQ(results["hits"][0]["highlights"][1]["index"].get<size_t>(), 0);
ASSERT_EQ(2, results["hits"][0]["highlights"][1]["snippets"].size());
ASSERT_STREQ("the <mark>truth</mark>", results["hits"][0]["highlights"][1]["snippets"][0].get<std::string>().c_str());
ASSERT_STREQ("<mark>truth</mark> about", results["hits"][0]["highlights"][1]["snippets"][1].get<std::string>().c_str());
ASSERT_EQ(2, results["hits"][0]["highlights"][1]["indices"].size());
ASSERT_EQ(0, results["hits"][0]["highlights"][1]["indices"][0]);
ASSERT_EQ(2, results["hits"][0]["highlights"][1]["indices"][1]);
ASSERT_EQ(2, results["hits"][1]["highlights"][0].size());
ASSERT_STREQ(results["hits"][1]["highlights"][0]["field"].get<std::string>().c_str(), "title");
@ -624,8 +643,14 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
ASSERT_EQ(3, results["hits"][1]["highlights"][1].size());
ASSERT_STREQ(results["hits"][1]["highlights"][1]["field"].get<std::string>().c_str(), "tags");
ASSERT_STREQ(results["hits"][1]["highlights"][1]["snippet"].get<std::string>().c_str(), "<mark>truth</mark>");
ASSERT_EQ(results["hits"][1]["highlights"][1]["index"].get<size_t>(), 1);
ASSERT_EQ(2, results["hits"][1]["highlights"][1]["snippets"].size());
ASSERT_STREQ("<mark>truth</mark>", results["hits"][1]["highlights"][1]["snippets"][0].get<std::string>().c_str());
ASSERT_STREQ("plain <mark>truth</mark>", results["hits"][1]["highlights"][1]["snippets"][1].get<std::string>().c_str());
ASSERT_EQ(2, results["hits"][1]["highlights"][1]["indices"].size());
ASSERT_EQ(1, results["hits"][1]["highlights"][1]["indices"][0]);
ASSERT_EQ(2, results["hits"][1]["highlights"][1]["indices"][1]);
// highlight fields must be ordered based on match score
results = coll_array_text->search("amazing movie", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY,
@ -634,10 +659,10 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
ASSERT_EQ(2, results["hits"][0]["highlights"].size());
ASSERT_EQ(3, results["hits"][0]["highlights"][0].size());
ASSERT_STREQ(results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str(), "tags");
ASSERT_STREQ(results["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str(),
"<mark>amazing</mark> <mark>movie</mark>");
ASSERT_EQ(results["hits"][0]["highlights"][0]["index"].get<size_t>(), 0);
ASSERT_STREQ("tags", results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str());
ASSERT_STREQ("<mark>amazing</mark> <mark>movie</mark>", results["hits"][0]["highlights"][0]["snippets"][0].get<std::string>().c_str());
ASSERT_EQ(1, results["hits"][0]["highlights"][0]["indices"].size());
ASSERT_EQ(0, results["hits"][0]["highlights"][0]["indices"][0]);
ASSERT_EQ(2, results["hits"][0]["highlights"][1].size());
ASSERT_STREQ(results["hits"][0]["highlights"][1]["field"].get<std::string>().c_str(), "title");