diff --git a/include/topster.h b/include/topster.h index 714aa2ea..c5b582ae 100644 --- a/include/topster.h +++ b/include/topster.h @@ -6,16 +6,15 @@ #include /* -* A bounded max heap that remembers the top-K elements seen so far +* Remembers the max-K elements seen so far using a min-heap */ template struct Topster { uint64_t data[MAX_SIZE]; - uint32_t smallest_index = 0; - uint32_t size = 0; + uint32_t size; + + Topster(): size(0){ - Topster(){ - data[smallest_index]= UINT_MAX; } template inline void swapMe(T& a, T& b) { @@ -37,44 +36,40 @@ struct Topster { } void add(const uint32_t&key, const uint32_t& val){ - uint32_t smallest_key, smallest_value; - unpack(data[smallest_index], smallest_key, smallest_value); - if (size >= MAX_SIZE) { - if(val < smallest_value) { + if(val <= getValueAt(0)) { // when incoming value is less than the smallest in the heap, ignore return; } - data[smallest_index] = pack(key, val); - int i = 0; + data[0] = pack(key, val); + uint32_t i = 0; // sift to maintain heap property while ((2*i+1) < MAX_SIZE) { - int next = 2*i + 1; - if (data[next] < data[next+1]) + uint32_t next = (uint32_t) (2 * i + 1); + if (next+1 < MAX_SIZE && getValueAt(next) > getValueAt(next+1)) { next++; + } - if (data[i] < data[next]) swapMe(data[i], data[next]); - else break; + if (getValueAt(i) > getValueAt(next)) { + swapMe(data[i], data[next]); + } else { + break; + } i = next; } } else { - // keep track of the smallest element's index - if(val < smallest_value) { - smallest_index = size; - } - - // insert at the end of the array, and sift it up to maintain heap property data[size++] = pack(key, val); - for (int i = size - 1; i > 0;) { - int parent = (i-1)/2; - if (data[parent] < data[i]) { + for (uint32_t i = size - 1; i > 0;) { + uint32_t parent = (i-1)/2; + if (getValueAt(parent) > getValueAt(i)) { swapMe(data[parent], data[i]); i = parent; + } else { + break; } - else break; } } } @@ -97,10 +92,17 @@ struct Topster { size = 0; } - uint32_t getKeyAt(uint32_t& index) { + uint32_t getKeyAt(uint32_t index) { uint32_t key; uint32_t value; unpack(data[index], key, value); return key; } + + uint32_t getValueAt(uint32_t index) { + uint32_t key; + uint32_t value; + unpack(data[index], key, value); + return value; + } }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index ec98609f..5a61475b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -99,6 +99,7 @@ std::vector Collection::search(std::string query, const int num_ int cost = 0; size_t total_results = 0; std::vector results; + Topster<100> topster; while(cost <= max_cost) { std::cout << "Searching with cost=" << cost << std::endl; @@ -125,7 +126,6 @@ std::vector Collection::search(std::string query, const int num_ continue; } - Topster<100> topster; const size_t combination_limit = 10; auto product = []( long long a, std::vector& b ) { return a*b.size(); }; long long int N = std::accumulate(token_leaves.begin(), token_leaves.end(), 1LL, product ); @@ -187,24 +187,31 @@ void Collection::score_results(Topster<100> &topster, const std::vector> token_positions; - // for each token in the query, find the positions that it appears in this document - for (art_leaf *token_leaf : query_suggestion) { - std::vector positions; - uint32_t doc_index = token_leaf->values->ids.indexOf(doc_id); - uint32_t start_offset = token_leaf->values->offset_index.at(doc_index); - uint32_t end_offset = (doc_index == token_leaf->values->ids.getLength() - 1) ? - token_leaf->values->offsets.getLength() : - token_leaf->values->offset_index.at(doc_index+1); + MatchScore mscore; - while(start_offset < end_offset) { - positions.push_back((uint16_t) token_leaf->values->offsets.at(start_offset)); - start_offset++; + if(query_suggestion.size() == 1) { + mscore = MatchScore{1, 1}; + } else { + // for each token in the query, find the positions that it appears in this document + for (art_leaf *token_leaf : query_suggestion) { + std::vector positions; + uint32_t doc_index = token_leaf->values->ids.indexOf(doc_id); + uint32_t start_offset = token_leaf->values->offset_index.at(doc_index); + uint32_t end_offset = (doc_index == token_leaf->values->ids.getLength() - 1) ? + token_leaf->values->offsets.getLength() : + token_leaf->values->offset_index.at(doc_index+1); + + while(start_offset < end_offset) { + positions.push_back((uint16_t) token_leaf->values->offsets.at(start_offset)); + start_offset++; + } + + token_positions.push_back(positions); } - token_positions.push_back(positions); + mscore = MatchScore::match_score(doc_id, token_positions); } - MatchScore mscore = MatchScore::match_score(doc_id, token_positions); const uint32_t cumulativeScore = ((uint32_t)(mscore.words_present * 16 + (20 - mscore.distance)) * 64000) + doc_scores.at(doc_id); /*