diff --git a/src/index.cpp b/src/index.cpp index 4944c1e4..e4feba90 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3034,20 +3034,47 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor); } + std::vector> vec_results; for (const auto& dist_label : dist_labels) { - uint32 seq_id = dist_label.second; + uint32_t seq_id = dist_label.second; auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - auto score = (1.0 - vec_dist_score) * 100000000000.0; - - auto found = topster->kv_map.find(seq_id); - - if (found != topster->kv_map.end() && found->second->match_score_index >= 0 && found->second->match_score_index <= 2) { - found->second->scores[found->second->match_score_index] += score; + auto score = (1.0 - vec_dist_score) * 100.0; + + vec_results.emplace_back(seq_id, score); + } + std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) { + return a.second > b.second; + }); + + topster->sort(); + // Reciprocal rank fusion + for(uint32_t i = 0; i < topster->size; i++) { + auto result = topster->getKV(i); + if(result->match_score_index < 0 || result->match_score_index > 2) { + continue; } + result->scores[result->match_score_index] = (1.0 / (i + 1)) * INT64_MAX * 0.7; } + for(int i = 0; i < vec_results.size(); i++) { + auto& result = vec_results[i]; + auto doc_id = result.first; + + auto result_it = topster->kv_map.find(doc_id); + + if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) { + auto result = result_it->second; + result->scores[result->match_score_index] += (1.0 / (i + 1)) * INT64_MAX * 0.3; + } else { + int64_t scores[3] = {0}; + scores[0] = (1.0 / (i + 1)) * INT64_MAX * 0.3; + int64_t match_score_index = 0; + KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores); + topster->add(&kv); + } + } } } diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 92b3ebfb..285ebd71 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -37,6 +37,12 @@ encoded_input_t TextEmbedder::Encode(const std::string& text) { auto input_ids = tokenizer_->AddSpecialToken(encoded); auto token_type_ids = tokenizer_->GenerateTypeId(encoded); auto attention_mask = std::vector(input_ids.size(), 1); + // BERT supports max sequence length of 512 + if (input_ids.size() > 512) { + input_ids.resize(512); + token_type_ids.resize(512); + attention_mask.resize(512); + } return {input_ids, token_type_ids, attention_mask}; }