diff --git a/include/collection.h b/include/collection.h
index 4e411549..d59a2e1d 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -19,13 +19,14 @@ class Collection {
 private:
+
     struct highlight_t {
         std::string field;
-        std::string snippet;
+        std::vector<std::string> snippets;
+        std::vector<size_t> indices;
         uint64_t match_score;
-        int index;
 
-        highlight_t(): match_score(0), index(-1) {
+        highlight_t() {
 
         }
     };
@@ -135,6 +136,8 @@ public:
     // strings under this length will be fully highlighted, instead of showing a snippet of relevant portion
     enum {SNIPPET_STR_ABOVE_LEN = 30};
 
+    enum {MAX_ARRAY_MATCHES = 5};
+
     // Using a $ prefix so that these meta keys stay above record entries in a lexicographically ordered KV store
     static constexpr const char* COLLECTION_META_PREFIX = "$CM";
     static constexpr const char* COLLECTION_NEXT_SEQ_PREFIX = "$CS";
diff --git a/src/collection.cpp b/src/collection.cpp
index 44bee62e..74bd502e 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -12,6 +12,24 @@
 #include "topster.h"
 #include "logger.h"
 
+struct match_index_t {
+    Match match;
+    uint64_t match_score = 0;
+    size_t index;
+
+    match_index_t(Match match, uint64_t match_score, size_t index): match(match), match_score(match_score),
+                                                                    index(index) {
+
+    }
+
+    bool operator<(const match_index_t& a) const {
+        if(match_score != a.match_score) {
+            return match_score > a.match_score;
+        }
+        return index < a.index;
+    }
+};
+
 Collection::Collection(const std::string name, const uint32_t collection_id, const uint32_t next_seq_id, Store *store,
                        const std::vector<field> &fields, const std::string & default_sorting_field,
                        const size_t num_indices):
@@ -568,7 +586,9 @@ Option<nlohmann::json> Collection::search(std::string query, const std::vector<
 [hunk body unrecoverable: the angle-bracket stripper that mangled this artifact consumed the span
  from the trailing "<" of this hunk header to the next ">" in the file]
@@ [hunk numbers unrecoverable] @@ Option<nlohmann::json> Collection::search(std::string query, const std::vector<
 [leading removed lines of this hunk unrecoverable: consumed by the same angle-bracket stripping]
+    std::vector<match_index_t> match_indices;
     for(size_t array_index = 0; array_index < array_token_positions.size(); array_index++) {
         const std::vector<std::vector<uint16_t>> & token_positions = array_token_positions[array_index];
@@ -674,67 +694,72 @@ void Collection::highlight_result(const field &search_field,
         const Match & this_match = Match::match(field_order_kv.key, token_positions);
         uint64_t this_match_score = this_match.get_match_score(1, field_order_kv.field_id);
-        if(this_match_score > match_score) {
-            match_score = this_match_score;
-            match = this_match;
-            matched_array_index = array_index;
-        }
+        match_indices.push_back(match_index_t(this_match, this_match_score, array_index));
     }
 
-    std::vector<std::string> tokens;
-    if(search_field.type == field_types::STRING) {
-        StringUtils::split(document[search_field.name], tokens, " ");
-    } else {
-        StringUtils::split(document[search_field.name][matched_array_index], tokens, " ");
-    }
+    const size_t max_array_matches = std::min((size_t)MAX_ARRAY_MATCHES, match_indices.size());
+    std::partial_sort(match_indices.begin(), match_indices.begin()+max_array_matches, match_indices.end());
 
-    // unpack `match.offset_diffs` into `token_indices`
-    std::vector<size_t> token_indices;
-    spp::sparse_hash_set<std::string> token_hits;
+    for(size_t index = 0; index < max_array_matches; index++) {
+        const match_index_t & match_index = match_indices[index];
+        const Match & match = match_index.match;
 
-    size_t num_tokens_found = (size_t) match.offset_diffs[0];
-    for(size_t i = 1; i <= num_tokens_found; i++) {
-        if(match.offset_diffs[i] != std::numeric_limits<int8_t>::max()) {
-            size_t token_index = (size_t)(match.start_offset + match.offset_diffs[i]);
-            token_indices.push_back(token_index);
-            std::string token = tokens[token_index];
-            string_utils.unicode_normalize(token);
-            token_hits.insert(token);
-        }
-    }
-
-    auto minmax = std::minmax_element(token_indices.begin(), token_indices.end());
-
-    // For longer strings, pick surrounding tokens within N tokens of min_index and max_index for the snippet
-    const size_t start_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? 0 :
-                               std::max(0, (int)(*(minmax.first) - 5));
-
-    const size_t end_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? tokens.size() :
-                             std::min((int)tokens.size(), (int)(*(minmax.second) + 5));
-
-    std::stringstream snippet_stream;
-    for(size_t snippet_index = start_index; snippet_index < end_index; snippet_index++) {
-        if(snippet_index != start_index) {
-            snippet_stream << " ";
-        }
-
-        std::string token = tokens[snippet_index];
-        string_utils.unicode_normalize(token);
-
-        if(token_hits.count(token) != 0) {
-            snippet_stream << "<mark>" + tokens[snippet_index] + "</mark>";
+        std::vector<std::string> tokens;
+        if(search_field.type == field_types::STRING) {
+            StringUtils::split(document[search_field.name], tokens, " ");
         } else {
-            snippet_stream << tokens[snippet_index];
+            StringUtils::split(document[search_field.name][match_index.index], tokens, " ");
+        }
+
+        // unpack `match.offset_diffs` into `token_indices`
+        std::vector<size_t> token_indices;
+        spp::sparse_hash_set<std::string> token_hits;
+
+        size_t num_tokens_found = (size_t) match.offset_diffs[0];
+        for(size_t i = 1; i <= num_tokens_found; i++) {
+            if(match.offset_diffs[i] != std::numeric_limits<int8_t>::max()) {
+                size_t token_index = (size_t)(match.start_offset + match.offset_diffs[i]);
+                token_indices.push_back(token_index);
+                std::string token = tokens[token_index];
+                string_utils.unicode_normalize(token);
+                token_hits.insert(token);
+            }
+        }
+
+        auto minmax = std::minmax_element(token_indices.begin(), token_indices.end());
+
+        // For longer strings, pick surrounding tokens within N tokens of min_index and max_index for the snippet
+        const size_t start_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? 0 :
+                                   std::max(0, (int)(*(minmax.first) - 5));
+
+        const size_t end_index = (tokens.size() <= SNIPPET_STR_ABOVE_LEN) ? tokens.size() :
+                                 std::min((int)tokens.size(), (int)(*(minmax.second) + 5));
+
+        std::stringstream snippet_stream;
+        for(size_t snippet_index = start_index; snippet_index < end_index; snippet_index++) {
+            if(snippet_index != start_index) {
+                snippet_stream << " ";
+            }
+
+            std::string token = tokens[snippet_index];
+            string_utils.unicode_normalize(token);
+
+            if(token_hits.count(token) != 0) {
+                snippet_stream << "<mark>" + tokens[snippet_index] + "</mark>";
+            } else {
+                snippet_stream << tokens[snippet_index];
+            }
+        }
+
+        highlight.snippets.push_back(snippet_stream.str());
+
+        if(search_field.type == field_types::STRING_ARRAY) {
+            highlight.indices.push_back(match_index.index);
         }
     }
 
     highlight.field = search_field.name;
-    highlight.snippet = snippet_stream.str();
-    highlight.match_score = match_score;
-
-    if(search_field.type == field_types::STRING_ARRAY) {
-        highlight.index = matched_array_index;
-    }
+    highlight.match_score = match_indices[0].match_score;
 
     for (auto it = leaf_to_indices.begin(); it != leaf_to_indices.end(); it++) {
         delete [] it->second;
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index 1415949f..665918eb 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -556,9 +556,17 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
 
     ASSERT_EQ(results["hits"][0]["highlights"].size(), 1);
     ASSERT_STREQ(results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str(), "tags");
-    ASSERT_STREQ(results["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str(),
-                 "<mark>truth</mark> <mark>about</mark>");
-    ASSERT_EQ(results["hits"][0]["highlights"][0]["index"].get<size_t>(), 2);
+
+    // an array's snippets must be sorted on match score, if match score is same, priority to be given to lower indices
+    ASSERT_EQ(3, results["hits"][0]["highlights"][0]["snippets"].size());
+    ASSERT_STREQ("<mark>truth</mark> <mark>about</mark>", results["hits"][0]["highlights"][0]["snippets"][0].get<std::string>().c_str());
+    ASSERT_STREQ("the <mark>truth</mark>", results["hits"][0]["highlights"][0]["snippets"][1].get<std::string>().c_str());
+    ASSERT_STREQ("<mark>about</mark> forever", results["hits"][0]["highlights"][0]["snippets"][2].get<std::string>().c_str());
+
+    ASSERT_EQ(3, results["hits"][0]["highlights"][0]["indices"].size());
+    ASSERT_EQ(2, results["hits"][0]["highlights"][0]["indices"][0]);
+    ASSERT_EQ(0, results["hits"][0]["highlights"][0]["indices"][1]);
+    ASSERT_EQ(1, results["hits"][0]["highlights"][0]["indices"][2]);
 
     results = coll_array_text->search("forever truth", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY,
                                       false, 0).get();
@@ -574,8 +582,14 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
     }
 
     ASSERT_STREQ(results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str(), "tags");
-    ASSERT_STREQ(results["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str(), "the <mark>truth</mark>");
-    ASSERT_EQ(results["hits"][0]["highlights"][0]["index"].get<size_t>(), 0);
+    ASSERT_EQ(3, results["hits"][0]["highlights"][0]["snippets"].size());
+    ASSERT_STREQ("the <mark>truth</mark>", results["hits"][0]["highlights"][0]["snippets"][0].get<std::string>().c_str());
+    ASSERT_STREQ("about <mark>forever</mark>", results["hits"][0]["highlights"][0]["snippets"][1].get<std::string>().c_str());
+    ASSERT_STREQ("<mark>truth</mark> about", results["hits"][0]["highlights"][0]["snippets"][2].get<std::string>().c_str());
+    ASSERT_EQ(3, results["hits"][0]["highlights"][0]["indices"].size());
+    ASSERT_EQ(0, results["hits"][0]["highlights"][0]["indices"][0]);
+    ASSERT_EQ(1, results["hits"][0]["highlights"][0]["indices"][1]);
+    ASSERT_EQ(2, results["hits"][0]["highlights"][0]["indices"][2]);
 
     results = coll_array_text->search("truth", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY,
                                       false, 0).get();
@@ -615,8 +629,13 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
 
     ASSERT_EQ(3, results["hits"][0]["highlights"][1].size());
     ASSERT_STREQ(results["hits"][0]["highlights"][1]["field"].get<std::string>().c_str(), "tags");
-    ASSERT_STREQ(results["hits"][0]["highlights"][1]["snippet"].get<std::string>().c_str(), "the <mark>truth</mark>");
-    ASSERT_EQ(results["hits"][0]["highlights"][1]["index"].get<size_t>(), 0);
+    ASSERT_EQ(2, results["hits"][0]["highlights"][1]["snippets"].size());
+    ASSERT_STREQ("the <mark>truth</mark>", results["hits"][0]["highlights"][1]["snippets"][0].get<std::string>().c_str());
+    ASSERT_STREQ("<mark>truth</mark> about", results["hits"][0]["highlights"][1]["snippets"][1].get<std::string>().c_str());
+
+    ASSERT_EQ(2, results["hits"][0]["highlights"][1]["indices"].size());
+    ASSERT_EQ(0, results["hits"][0]["highlights"][1]["indices"][0]);
+    ASSERT_EQ(2, results["hits"][0]["highlights"][1]["indices"][1]);
 
     ASSERT_EQ(2, results["hits"][1]["highlights"][0].size());
     ASSERT_STREQ(results["hits"][1]["highlights"][0]["field"].get<std::string>().c_str(), "title");
@@ -624,8 +643,14 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
 
     ASSERT_EQ(3, results["hits"][1]["highlights"][1].size());
     ASSERT_STREQ(results["hits"][1]["highlights"][1]["field"].get<std::string>().c_str(), "tags");
-    ASSERT_STREQ(results["hits"][1]["highlights"][1]["snippet"].get<std::string>().c_str(), "<mark>truth</mark>");
-    ASSERT_EQ(results["hits"][1]["highlights"][1]["index"].get<size_t>(), 1);
+
+    ASSERT_EQ(2, results["hits"][1]["highlights"][1]["snippets"].size());
+    ASSERT_STREQ("<mark>truth</mark>", results["hits"][1]["highlights"][1]["snippets"][0].get<std::string>().c_str());
+    ASSERT_STREQ("plain <mark>truth</mark>", results["hits"][1]["highlights"][1]["snippets"][1].get<std::string>().c_str());
+
+    ASSERT_EQ(2, results["hits"][1]["highlights"][1]["indices"].size());
+    ASSERT_EQ(1, results["hits"][1]["highlights"][1]["indices"][0]);
+    ASSERT_EQ(2, results["hits"][1]["highlights"][1]["indices"][1]);
 
     // highlight fields must be ordered based on match score
     results = coll_array_text->search("amazing movie", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY,
@@ -634,10 +659,10 @@ TEST_F(CollectionTest, ArrayStringFieldHighlight) {
 
     ASSERT_EQ(2, results["hits"][0]["highlights"].size());
     ASSERT_EQ(3, results["hits"][0]["highlights"][0].size());
-    ASSERT_STREQ(results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str(), "tags");
-    ASSERT_STREQ(results["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str(),
-                 "<mark>amazing</mark> <mark>movie</mark>");
-    ASSERT_EQ(results["hits"][0]["highlights"][0]["index"].get<size_t>(), 0);
+    ASSERT_STREQ("tags", results["hits"][0]["highlights"][0]["field"].get<std::string>().c_str());
+    ASSERT_STREQ("<mark>amazing</mark> <mark>movie</mark>", results["hits"][0]["highlights"][0]["snippets"][0].get<std::string>().c_str());
+    ASSERT_EQ(1, results["hits"][0]["highlights"][0]["indices"].size());
+    ASSERT_EQ(0, results["hits"][0]["highlights"][0]["indices"][0]);
 
     ASSERT_EQ(2, results["hits"][0]["highlights"][1].size());
     ASSERT_STREQ(results["hits"][0]["highlights"][1]["field"].get<std::string>().c_str(), "title");