Speed up match scoring.

kishorenc 2020-08-25 19:51:44 +05:30
parent 2f21819db0
commit 59dc1fef0d
4 changed files with 186 additions and 203 deletions

View File

@@ -17,185 +17,187 @@
 const size_t WINDOW_SIZE = 10;
 const uint16_t MAX_DISPLACEMENT = std::numeric_limits<uint16_t>::max();
 const uint16_t MAX_TOKENS_DISTANCE = 100;

 struct TokenOffset {
     uint8_t token_id;        // token identifier
     uint16_t offset;         // token's offset in the text
-    uint16_t offset_index;   // index of the offset in the vector
+    uint16_t offset_index;   // index of the offset in the offset vector

-    bool operator() (const TokenOffset& a, const TokenOffset& b) {
+    bool operator()(const TokenOffset &a, const TokenOffset &b) {
         return a.offset > b.offset;
     }
+
+    bool operator>(const TokenOffset &a) const {
+        return offset > a.offset;
+    }
 };

 struct Match {
-    uint8_t words_present;
-    uint8_t distance;
-    uint16_t start_offset;
-    char offset_diffs[16];
+    uint8_t words_present;
+    uint8_t distance;
+    std::vector<TokenOffset> offsets;

-    Match(): words_present(0), distance(0), start_offset(0) {
+    Match() : words_present(0), distance(0) {
     }

-    Match(uint8_t words_present, uint8_t distance, uint16_t start_offset, char *offset_diffs_stacked):
-            words_present(words_present), distance(distance), start_offset(start_offset) {
-        memcpy(offset_diffs, offset_diffs_stacked, 16);
+    Match(uint8_t words_present, uint8_t distance) : words_present(words_present), distance(distance) {
     }

     // Construct a single match score from individual components (for multi-field sort)
     inline uint64_t get_match_score(const uint32_t total_cost, const uint8_t field_id) const {
-        uint64_t match_score = ((int64_t)(words_present) << 24) |
-                               ((int64_t)(255 - total_cost) << 16) |
-                               ((int64_t)(distance) << 8) |
-                               ((int64_t)(field_id));
+        uint64_t match_score = ((int64_t) (words_present) << 24) |
+                               ((int64_t) (255 - total_cost) << 16) |
+                               ((int64_t) (distance) << 8) |
+                               ((int64_t) (field_id));
         return match_score;
     }

     static void print_token_offsets(std::vector<std::vector<uint16_t>> &token_offsets) {
-        for(auto offsets: token_offsets) {
-            for(auto offset: offsets) {
+        for (auto offsets: token_offsets) {
+            for (auto offset: offsets) {
                 LOG(INFO) << offset << ", ";
             }
             LOG(INFO) << "";
         }
     }

-    static inline void addTopOfHeapToWindow(TokenOffsetHeap &heap, std::queue<TokenOffset> &window,
-                                            const std::vector<std::vector<uint16_t>> &token_offsets,
-                                            uint16_t *token_offset) {
-        TokenOffset top = heap.top();
-        heap.pop();
-        window.push(top);
-        token_offset[top.token_id] = std::min(token_offset[top.token_id], top.offset);
-        top.offset_index++;
-
-        // Must refill the heap - push the next offset of the same token
-        if(top.offset_index < token_offsets[top.token_id].size()) {
-            heap.push(TokenOffset{top.token_id, token_offsets[top.token_id][top.offset_index], top.offset_index});
-        }
-    }
-
-    static void pack_token_offsets(const uint16_t* min_token_offset, const size_t num_tokens,
-                                   const uint16_t token_start_offset, char *offset_diffs) {
-        offset_diffs[0] = (char) num_tokens;
-        size_t j = 1;
-
-        for(size_t i = 0; i < num_tokens; i++) {
-            if(min_token_offset[i] != MAX_DISPLACEMENT) {
-                offset_diffs[j] = (int8_t)(min_token_offset[i] - token_start_offset);
-            } else {
-                offset_diffs[j] = std::numeric_limits<int8_t>::max();
-            }
-            j++;
-        }
-    }
-
-    /*
-     * Given *sorted offsets* of each target token in a *single* document (token_offsets), generates a score indicating:
-     * a) How many tokens are present within a match window
-     * b) The proximity between the tokens within the match window
-     *
-     * We use a priority queue to read the offset vectors in a sorted manner, slide a window of a given size, and
-     * compute the max_match and min_displacement of target tokens across the windows.
-     */
-    static Match match(uint32_t doc_id, const std::vector<std::vector<uint16_t>> & token_offsets) {
-        std::priority_queue<TokenOffset, std::vector<TokenOffset>, TokenOffset> heap;
-
-        const size_t tokens_size = std::min(token_offsets.size(), WINDOW_SIZE);
-
-        for(uint8_t token_id=0; token_id < tokens_size; token_id++) {
-            heap.push(TokenOffset{token_id, token_offsets[token_id].front(), 0});
-        }
-
-        // heap now contains the first occurring offset of each token in the given document
-
-        uint16_t max_match = 0;
-        uint16_t min_displacement = MAX_DISPLACEMENT;
-
-        std::queue<TokenOffset> window;
-        uint16_t token_offset[WINDOW_SIZE] = { };
-        std::fill_n(token_offset, WINDOW_SIZE, MAX_DISPLACEMENT);
-
-        // used to store token offsets of the best-matched window
-        uint16_t min_token_offset[WINDOW_SIZE];
-        std::fill_n(min_token_offset, WINDOW_SIZE, MAX_DISPLACEMENT);
-
-        do {
-            if(window.empty()) {
-                addTopOfHeapToWindow(heap, window, token_offsets, token_offset);
-            }
-
-            D(LOG(INFO) << "Loop till window fills... doc_id: " << doc_id;)
-
-            // Fill the queue with tokens within a given window frame size of the start offset
-            // At the same time, we also record the *last* occurrence of each token within the window
-            // For e.g. if `cat` appeared at offsets 1,3 and 5, we will record `token_offset[cat] = 5`
-            const uint16_t start_offset = window.front().offset;
-            while(!heap.empty() && heap.top().offset < start_offset+WINDOW_SIZE) {
-                addTopOfHeapToWindow(heap, window, token_offsets, token_offset);
-            }
-
-            D(LOG(INFO) << "----");
-
-            uint16_t prev_pos = MAX_DISPLACEMENT;
-            uint16_t num_match = 0;
-            uint16_t displacement = 0;
-
-            for(size_t token_id=0; token_id<tokens_size; token_id++) {
-                // If a token appeared within the window, we would have recorded its offset
-                if(token_offset[token_id] != MAX_DISPLACEMENT) {
-                    num_match++;
-                    if(prev_pos == MAX_DISPLACEMENT) { // for the first word
-                        prev_pos = token_offset[token_id];
-                        displacement = 0;
-                    } else {
-                        // Calculate the distance between the tokens.
-                        // This will be 0 when all the tokens are adjacent to each other
-                        D(LOG(INFO) << "prev_pos: " << prev_pos << " , curr_pos: " << token_offset[token_id]);
-                        displacement += abs(token_offset[token_id]-prev_pos);
-                        prev_pos = token_offset[token_id];
-                    }
-                }
-            }
-
-            // Normalize displacement such that matches of same length always have the same displacement
-            // Ensure that displacement is > 0 -- happens if tokens repeat (displacement will be 0 but num_match > 1)
-            displacement = std::max(0, int16_t(displacement) - num_match + 1);
-
-            D(LOG(INFO) << std::endl << "!!!displacement: " << displacement << " | num_match: " << num_match);
-
-            // Track the best `num_match` and `displacement` (in that order) seen so far across all the windows
-            if(num_match > max_match || (num_match == max_match && displacement < min_displacement)) {
-                min_displacement = displacement;
-                // record the token positions (for highlighting)
-                memcpy(min_token_offset, token_offset, tokens_size*sizeof(uint16_t));
-                max_match = num_match;
-            }
-
-            // As we slide the window, drop the first token of the window from the computation
-            token_offset[window.front().token_id] = MAX_DISPLACEMENT;
-            window.pop();
-        } while(!heap.empty());
-
-        // do run-length encoding of the min token positions/offsets
-        uint16_t token_start_offset = 0;
-        char packed_offset_diffs[16];
-        std::fill_n(packed_offset_diffs, 16, 0);
-
-        // identify the first token which is actually present and use that as the base for run-length encoding
-        size_t token_index = 0;
-        while(token_index < tokens_size) {
-            if(min_token_offset[token_index] != MAX_DISPLACEMENT) {
-                token_start_offset = min_token_offset[token_index];
-                break;
-            }
-            token_index++;
-        }
-
-        const uint8_t distance = MAX_TOKENS_DISTANCE - min_displacement;
-        pack_token_offsets(min_token_offset, tokens_size, token_start_offset, packed_offset_diffs);
-        return Match(max_match, distance, token_start_offset, packed_offset_diffs);
-    }
+    template<typename T>
+    void sort3(std::vector<T> &a) {
+        if (a[0] > a[1]) {
+            if (a[1] > a[2]) {
+                return;
+            } else if (a[0] > a[2]) {
+                std::swap(a[1], a[2]);
+            } else {
+                T tmp = std::move(a[0]);
+                a[0] = std::move(a[2]);
+                a[2] = std::move(a[1]);
+                a[1] = std::move(tmp);
+            }
+        } else {
+            if (a[0] > a[2]) {
+                std::swap(a[0], a[1]);
+            } else if (a[2] > a[1]) {
+                std::swap(a[0], a[2]);
+            } else {
+                T tmp = std::move(a[0]);
+                a[0] = std::move(a[1]);
+                a[1] = std::move(a[2]);
+                a[2] = std::move(tmp);
+            }
+        }
+    }
+
+    /*
+       Given *sorted offsets* of each target token in a *single* document (token_offsets), generates a score indicating:
+       a) How many tokens are present within a match window
+       b) The proximity between the tokens within the match window
+
+       How it works:
+       ------------
+       Create vector with first offset from each token.
+       Sort vector descending.
+       Calculate distance, use only tokens within max window size from lowest offset.
+       Reassign best distance and window if found.
+       Pop end of vector (smallest offset).
+       Push to vector next offset of token just popped.
+       Until queue size is 1.
+    */
+    Match(uint32_t doc_id, const std::vector<std::vector<uint16_t>> &token_offsets, bool populate_window=true) {
+        // in case the number of tokens in the query is greater than the max window size
+        const size_t tokens_size = std::min(token_offsets.size(), WINDOW_SIZE);
+
+        std::vector<TokenOffset> window(tokens_size);
+        for (size_t token_id = 0; token_id < tokens_size; token_id++) {
+            window[token_id] = TokenOffset{static_cast<uint8_t>(token_id), token_offsets[token_id][0], 0};
+        }
+
+        std::vector<TokenOffset> best_window;
+        if(populate_window) {
+            best_window = window;
+        }
+
+        size_t best_num_match = 1;
+        size_t best_displacement = MAX_DISPLACEMENT;
+
+        while (window.size() > 1) {
+            if(window.size() == 3) {
+                sort3<TokenOffset>(window);
+            } else {
+                std::sort(window.begin(), window.end(), std::greater<TokenOffset>()); // descending comparator
+            }
+
+            size_t min_offset = window.back().offset;
+
+            size_t this_displacement = 0;
+            size_t this_num_match = 0;
+            std::vector<TokenOffset> this_window(tokens_size);
+
+            for (size_t i = 0; i < window.size(); i++) {
+                if(populate_window) {
+                    this_window[window[i].token_id] = window[i];
+                }
+
+                if ((window[i].offset - min_offset) <= WINDOW_SIZE) {
+                    uint16_t next_offset = (i == window.size() - 1) ? window[i].offset : window[i + 1].offset;
+                    this_displacement += window[i].offset - next_offset;
+                    this_num_match++;
+                } else {
+                    // to indicate that this offset should not be considered
+                    if(populate_window) {
+                        this_window[window[i].token_id].offset = MAX_DISPLACEMENT;
+                    }
+                }
+            }
+
+            if(populate_window) {
+                this_window[window.back().token_id] = window.back();
+            }
+
+            if (this_num_match > best_num_match && this_displacement < best_displacement) {
+                best_displacement = this_displacement;
+                best_num_match = this_num_match;
+                if(populate_window) {
+                    best_window = this_window;
+                }
+            }
+
+            if (best_displacement == (window.size() - 1)) {
+                // this is the best we can get, so quit early!
+                break;
+            }
+
+            // refill the window with the next possible smallest offset across the available token offsets
+            const TokenOffset &smallest_offset = window.back();
+            window.pop_back();
+
+            const uint8_t token_id = smallest_offset.token_id;
+            const std::vector<uint16_t> &this_token_offsets = token_offsets[token_id];
+
+            if (smallest_offset.offset == this_token_offsets.back()) {
+                // no more offsets for this token
+                continue;
+            }
+
+            // Push the next offset of the token that was just popped
+            uint16_t next_offset_index = (smallest_offset.offset_index + 1);
+            TokenOffset token_offset{token_id, this_token_offsets[next_offset_index], next_offset_index};
+            window.emplace_back(token_offset);
+        }
+
+        if (best_displacement == MAX_DISPLACEMENT) {
+            best_displacement = 0;
+        }
+
+        uint8_t best_distance = uint8_t(100 - best_displacement);
+
+        words_present = best_num_match;
+        distance = best_distance;
+        if(populate_window) {
+            offsets = best_window;
+        }
+    }
 };
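
The rewritten constructor above replaces the heap-and-queue sliding window with a plain vector that is re-sorted on every iteration. Below is a minimal standalone sketch of that loop, with hypothetical names (Candidate, the sample offsets) that are not from the commit; it mirrors the diff's logic, including its strict && best-window test, but it is not the shipped implementation.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Constants mirroring the header above.
const size_t WINDOW_SIZE = 10;
const uint16_t MAX_DISPLACEMENT = UINT16_MAX;

// Hypothetical stand-in for TokenOffset.
struct Candidate {
    uint8_t token_id;
    uint16_t offset;
    uint16_t offset_index;
};

int main() {
    // One sorted offset vector per query token (sample data, not from the commit).
    std::vector<std::vector<uint16_t>> token_offsets = {
        {38, 50, 170}, {39, 140, 171}, {169, 180}
    };

    // Seed the window with the first offset of every token.
    std::vector<Candidate> window;
    for (size_t i = 0; i < token_offsets.size(); i++) {
        window.push_back({(uint8_t) i, token_offsets[i][0], 0});
    }

    size_t best_num_match = 1, best_displacement = MAX_DISPLACEMENT;

    while (window.size() > 1) {
        // Sort descending by offset, so window.back() holds the smallest offset.
        std::sort(window.begin(), window.end(),
                  [](const Candidate &a, const Candidate &b) { return a.offset > b.offset; });

        const uint16_t min_offset = window.back().offset;
        size_t num_match = 0, displacement = 0;

        // Count tokens within WINDOW_SIZE of the smallest offset; sum the gaps between them.
        for (size_t i = 0; i < window.size(); i++) {
            if ((size_t)(window[i].offset - min_offset) <= WINDOW_SIZE) {
                uint16_t next = (i + 1 < window.size()) ? window[i + 1].offset : window[i].offset;
                displacement += window[i].offset - next;
                num_match++;
            }
        }

        // Same strict && test as the diff: both components must improve together.
        if (num_match > best_num_match && displacement < best_displacement) {
            best_num_match = num_match;
            best_displacement = displacement;
        }

        // Slide: drop the smallest offset, pull in the next offset of that same token.
        Candidate popped = window.back();
        window.pop_back();
        const std::vector<uint16_t> &offs = token_offsets[popped.token_id];
        if (popped.offset_index + 1u < offs.size()) {
            window.push_back({popped.token_id, offs[popped.offset_index + 1],
                              (uint16_t)(popped.offset_index + 1)});
        }
    }

    if (best_displacement == MAX_DISPLACEMENT) best_displacement = 0;

    // For this input: tokens 0 and 1 co-occur at offsets 38/39, so 2 words match with gap 1.
    printf("words_present=%zu distance=%zu\n", best_num_match, 100 - best_displacement);
    return 0;
}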

View File

@@ -1233,7 +1233,7 @@ void Collection::highlight_result(const field &search_field,
                 continue;
             }

-            const Match & this_match = Match::match(field_order_kv->key, token_positions);
+            const Match & this_match = Match(field_order_kv->key, token_positions);
             uint64_t this_match_score = this_match.get_match_score(1, field_order_kv->field_id);
             match_indices.emplace_back(this_match, this_match_score, array_index);
         }
@@ -1258,14 +1258,12 @@ void Collection::highlight_result(const field &search_field,
         StringUtils::split(document[search_field.name][match_index.index], tokens, " ");
     }

-    // unpack `match.offset_diffs` into `token_indices`
     std::vector<size_t> token_indices;
     spp::sparse_hash_set<std::string> token_hits;

-    size_t num_tokens_found = (size_t) match.offset_diffs[0];
-    for(size_t i = 1; i <= num_tokens_found; i++) {
-        if(match.offset_diffs[i] != std::numeric_limits<int8_t>::max()) {
-            size_t token_index = (size_t)(match.start_offset + match.offset_diffs[i]);
+    for(size_t i = 0; i < match.words_present; i++) {
+        if(match.offsets[i].offset != MAX_DISPLACEMENT) {
+            size_t token_index = (size_t)(match.offsets[i].offset);
             token_indices.push_back(token_index);
             std::string token = tokens[token_index];
             string_utils.unicode_normalize(token);
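
With offset_diffs gone, highlighting reads token positions straight from match.offsets; a token absent from the best window is marked with MAX_DISPLACEMENT instead of the old int8_t sentinel. A small self-contained sketch of that filtering follows; the trimmed TokenOffset stand-in and the sample values are illustrative, not from the commit.

#include <cstdint>
#include <cstdio>
#include <vector>

const uint16_t MAX_DISPLACEMENT = UINT16_MAX;

// Hypothetical stand-in for the TokenOffset entries stored in Match::offsets.
struct TokenOffset {
    uint8_t token_id;
    uint16_t offset;
};

int main() {
    // Sample best window: token 1 missed the window, so it carries MAX_DISPLACEMENT.
    std::vector<TokenOffset> offsets = {{0, 14}, {1, MAX_DISPLACEMENT}, {2, 16}};

    std::vector<size_t> token_indices;
    for (const TokenOffset &to : offsets) {
        if (to.offset != MAX_DISPLACEMENT) {
            token_indices.push_back(to.offset); // offsets are now direct token indices
        }
    }

    for (size_t idx : token_indices) {
        printf("highlight token at index %zu\n", idx); // prints indices 14 and 16
    }
    return 0;
}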

View File

@@ -1442,9 +1442,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
     //auto begin = std::chrono::high_resolution_clock::now();

-    char empty_offset_diffs[16];
-    std::fill_n(empty_offset_diffs, 16, 0);
-    Match single_token_match = Match(1, 0, 0, empty_offset_diffs);
+    Match single_token_match = Match(1, 0);
     const uint64_t single_token_match_score = single_token_match.get_match_score(total_cost, field_id);

     std::unordered_map<std::string, size_t> facet_to_id;
@@ -1472,7 +1470,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
             if(token_positions.empty()) {
                 continue;
             }

-            const Match & match = Match::match(seq_id, token_positions);
+            const Match & match = Match(seq_id, token_positions, false);
             uint64_t this_match_score = match.get_match_score(total_cost, field_id);

             if(this_match_score > match_score) {
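
For reference, get_match_score() packs its components into one sortable integer: words_present above bit 24, inverted cost in bits 16-23, distance in bits 8-15, and field_id in the low byte, so a higher word count always outranks a lower cost or distance. A self-contained sketch with made-up sample values:

#include <cstdint>
#include <cstdio>

int main() {
    // Sample component values (made up for illustration).
    uint8_t words_present = 3; // tokens found in the best window
    uint32_t total_cost = 2;   // typo cost of the matched query tokens
    uint8_t distance = 98;     // 100 - displacement; higher is better
    uint8_t field_id = 4;      // searched field's priority

    // Same packing as Match::get_match_score() in the header diff.
    uint64_t match_score = ((int64_t) (words_present) << 24) |
                           ((int64_t) (255 - total_cost) << 16) |
                           ((int64_t) (distance) << 8) |
                           ((int64_t) (field_id));

    // Prints 0x03fd6204: words_present dominates, then inverted cost, distance, field_id.
    printf("0x%08llx\n", (unsigned long long) match_score);
    return 0;
}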

View File

@@ -1,49 +1,6 @@
 #include <gtest/gtest.h>
 #include <match_score.h>

-TEST(MatchTest, ShouldPackTokenOffsets) {
-    uint16_t min_token_offset1[3] = {567, 568, 570};
-    char offset_diffs[16];
-
-    Match::pack_token_offsets(min_token_offset1, 3, 567, offset_diffs);
-    ASSERT_EQ(3, offset_diffs[0]);
-    ASSERT_EQ(0, offset_diffs[1]);
-    ASSERT_EQ(1, offset_diffs[2]);
-    ASSERT_EQ(3, offset_diffs[3]);
-
-    uint16_t min_token_offset2[3] = {0, 1, 2};
-    Match::pack_token_offsets(min_token_offset2, 3, 0, offset_diffs);
-    ASSERT_EQ(3, offset_diffs[0]);
-    ASSERT_EQ(0, offset_diffs[1]);
-    ASSERT_EQ(1, offset_diffs[2]);
-    ASSERT_EQ(2, offset_diffs[3]);
-
-    uint16_t min_token_offset3[1] = {123};
-    Match::pack_token_offsets(min_token_offset3, 1, 123, offset_diffs);
-    ASSERT_EQ(1, offset_diffs[0]);
-    ASSERT_EQ(0, offset_diffs[1]);
-
-    // a token might not have an offset because it might not be in the best matching window
-    uint16_t min_token_offset4[3] = {0, MAX_DISPLACEMENT, 2};
-    Match::pack_token_offsets(min_token_offset4, 3, 0, offset_diffs);
-    ASSERT_EQ(3, offset_diffs[0]);
-    ASSERT_EQ(0, offset_diffs[1]);
-    ASSERT_EQ(std::numeric_limits<int8_t>::max(), offset_diffs[2]);
-    ASSERT_EQ(2, offset_diffs[3]);
-
-    uint16_t min_token_offset5[3] = {MAX_DISPLACEMENT, 2, 4};
-    Match::pack_token_offsets(min_token_offset5, 3, 2, offset_diffs);
-    ASSERT_EQ(3, offset_diffs[0]);
-    ASSERT_EQ(std::numeric_limits<int8_t>::max(), offset_diffs[1]);
-    ASSERT_EQ(0, offset_diffs[2]);
-    ASSERT_EQ(2, offset_diffs[3]);
-}
-
 TEST(MatchTest, TokenOffsetsExceedWindowSize) {
     std::vector<std::vector<uint16_t>> token_positions = {
         std::vector<uint16_t>({1}), std::vector<uint16_t>({1}), std::vector<uint16_t>({1}), std::vector<uint16_t>({1}),
@@ -51,7 +8,35 @@ TEST(MatchTest, TokenOffsetsExceedWindowSize) {
         std::vector<uint16_t>({1}), std::vector<uint16_t>({1}), std::vector<uint16_t>({1}), std::vector<uint16_t>({1})
     };

-    const Match & this_match = Match::match(100, token_positions);
-
-    ASSERT_EQ(WINDOW_SIZE, this_match.words_present);
+    const Match & this_match = Match(100, token_positions);
+
+    ASSERT_EQ(WINDOW_SIZE, (size_t) this_match.words_present);
 }
+
+TEST(MatchTest, MatchScoreV2) {
+    std::vector<std::vector<uint16_t>> token_offsets;
+    token_offsets.push_back({38, 50, 170, 187, 195, 222});
+    token_offsets.push_back({39, 140, 171, 189, 223});
+    token_offsets.push_back({169, 180});
+
+//    token_offsets.push_back({38, 50, 187, 195, 201});
+//    token_offsets.push_back({120, 167, 171, 223}); // 39,
+//    token_offsets.push_back({240, 250});
+
+    size_t total_distance = 0, words_present = 0, offset_sum = 0;
+
+    auto begin = std::chrono::high_resolution_clock::now();
+
+    for(size_t i = 0; i < 1; i++) {
+        auto match = Match(100, token_offsets, true);
+        total_distance += match.distance;
+        words_present += match.words_present;
+        offset_sum += match.offsets.size();
+    }
+
+    uint64_t timeNanos = std::chrono::duration_cast<std::chrono::milliseconds>(
+        std::chrono::high_resolution_clock::now() - begin).count();
+    LOG(INFO) << "Time taken: " << timeNanos;
+    LOG(INFO) << total_distance << ", " << words_present << ", " << offset_sum;
+}
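
The new MatchScoreV2 test only logs aggregates. As a hypothetical follow-up (not part of the commit), the trivial single-token case could be asserted directly; the expected values follow from the constructor in the header diff, where an untouched single-entry window leaves best_displacement at 0, so distance comes out as 100.

TEST(MatchTest, SingleTokenMatch) {
    std::vector<std::vector<uint16_t>> token_offsets = {{7}};

    auto match = Match(0, token_offsets, true);

    ASSERT_EQ(1, match.words_present);            // the lone token is always found
    ASSERT_EQ(100, match.distance);               // displacement defaults to 0
    ASSERT_EQ((size_t) 1, match.offsets.size());  // best window holds one entry
    ASSERT_EQ(7, match.offsets[0].offset);        // kept at its original offset
}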