Reciprocal Rank Fusion

This commit is contained in:
ozanarmagan 2023-03-13 17:15:25 +03:00
parent bf8ded2bac
commit a0d1b74579
2 changed files with 40 additions and 7 deletions

View File

@ -3034,20 +3034,47 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
}
std::vector<std::pair<uint32_t,float>> vec_results;
for (const auto& dist_label : dist_labels) {
uint32 seq_id = dist_label.second;
uint32_t seq_id = dist_label.second;
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
dist_label.first;
auto score = (1.0 - vec_dist_score) * 100000000000.0;
auto found = topster->kv_map.find(seq_id);
if (found != topster->kv_map.end() && found->second->match_score_index >= 0 && found->second->match_score_index <= 2) {
found->second->scores[found->second->match_score_index] += score;
auto score = (1.0 - vec_dist_score) * 100.0;
vec_results.emplace_back(seq_id, score);
}
std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
return a.second > b.second;
});
topster->sort();
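// Reciprocal rank fusion: scores are built from 1 / (position + 1) terms scaled by INT64_MAX,
// weighted 0.7 for the zero-based position in the sorted topster and 0.3 for the position in vec_results.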
for(uint32_t i = 0; i < topster->size; i++) {
auto result = topster->getKV(i);
if(result->match_score_index < 0 || result->match_score_index > 2) {
continue;
}
result->scores[result->match_score_index] = (1.0 / (i + 1)) * INT64_MAX * 0.7;
}
for(int i = 0; i < vec_results.size(); i++) {
auto& result = vec_results[i];
auto doc_id = result.first;
auto result_it = topster->kv_map.find(doc_id);
if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) {
auto result = result_it->second;
result->scores[result->match_score_index] += (1.0 / (i + 1)) * INT64_MAX * 0.3;
} else {
int64_t scores[3] = {0};
scores[0] = (1.0 / (i + 1)) * INT64_MAX * 0.3;
int64_t match_score_index = 0;
KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
topster->add(&kv);
}
}
}
}

View File

@ -37,6 +37,12 @@ encoded_input_t TextEmbedder::Encode(const std::string& text) {
auto input_ids = tokenizer_->AddSpecialToken(encoded);
auto token_type_ids = tokenizer_->GenerateTypeId(encoded);
auto attention_mask = std::vector<int64_t>(input_ids.size(), 1);
// BERT supports a maximum sequence length of 512 tokens, so longer inputs are truncated to 512.
if (input_ids.size() > 512) {
input_ids.resize(512);
token_type_ids.resize(512);
attention_mask.resize(512);
}
return {input_ids, token_type_ids, attention_mask};
}