diff --git a/include/text_embedder.h b/include/text_embedder.h
index 817e7cf8..1cf79fa9 100644
--- a/include/text_embedder.h
+++ b/include/text_embedder.h
@@ -45,7 +45,7 @@ class TextEmbedder {
     std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
     std::unique_ptr<RemoteEmbedder> remote_embedder_;
     std::string vocab_file_name;
-    static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
+    static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input, const std::vector<int64_t>& attention_mask);
     std::string output_tensor_name;
     size_t num_dim;
     std::mutex mutex_;
diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp
index b6780505..1955409d 100644
--- a/src/text_embedder.cpp
+++ b/src/text_embedder.cpp
@@ -91,16 +91,28 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims)
 }
 
-std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs) {
+std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs, const std::vector<int64_t>& attention_mask) {
     std::vector<float> pooled_output;
     for (int i = 0; i < inputs[0].size(); i++) {
         float sum = 0;
         for (int j = 0; j < inputs.size(); j++) {
-            sum += inputs[j][i];
+            sum += inputs[j][i] * attention_mask[j];
         }
-        pooled_output.push_back(sum / inputs.size());
+        pooled_output.push_back(sum);
     }
+
+    // get sum of attention mask
+    float sum_attention_mask = 0;
+    for(auto& val : attention_mask) {
+        sum_attention_mask += val;
+    }
+
+    // divide by sum of attention mask
+    for(auto& val : pooled_output) {
+        val /= sum_attention_mask;
+    }
+
     return pooled_output;
 }
@@ -171,7 +183,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote
             }
             output.push_back(temp);
         }
-        auto pooled_output = mean_pooling(output);
+        auto pooled_output = mean_pooling(output, encoded_input.attention_mask);
         return embedding_res_t(pooled_output);
     }
 }
@@ -277,7 +289,7 @@ std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs,
             if(tokenizer_->get_tokenizer_type() != TokenizerType::clip) {
-                outputs.push_back(embedding_res_t(mean_pooling(output)));
+                outputs.push_back(embedding_res_t(mean_pooling(output, encoded_inputs.attention_mask[i])));
             }
         }
     }
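
For context, the bug this patch fixes: the old `sum / inputs.size()` averaged over the full (padded) sequence length, so padding-token embeddings diluted the mean. The new version zeroes out masked positions and normalizes by the number of real tokens. Below is a self-contained sketch of that masked mean pooling, under the assumption that the mask is `int64_t` as in the ONNX-style tokenizer output above; the names `masked_mean_pooling` and `hidden_states` are illustrative only, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Masked mean pooling: average token embeddings, counting only
// positions where the attention mask is 1 (real tokens, not padding).
std::vector<float> masked_mean_pooling(const std::vector<std::vector<float>>& inputs,
                                       const std::vector<int64_t>& attention_mask) {
    assert(inputs.size() == attention_mask.size());
    std::vector<float> pooled(inputs[0].size(), 0.0f);
    float mask_sum = 0.0f;
    for (size_t j = 0; j < inputs.size(); j++) {
        mask_sum += attention_mask[j];
        for (size_t i = 0; i < pooled.size(); i++) {
            pooled[i] += inputs[j][i] * attention_mask[j];  // masked positions contribute 0
        }
    }
    for (auto& v : pooled) {
        v /= mask_sum;  // normalize by real-token count, not sequence length
    }
    return pooled;
}

int main() {
    // Two real tokens followed by one padding token (hypothetical values).
    std::vector<std::vector<float>> hidden_states = {
        {1.0f, 2.0f},  // token 0
        {3.0f, 4.0f},  // token 1
        {9.0f, 9.0f}   // padding; must not affect the mean
    };
    std::vector<int64_t> attention_mask = {1, 1, 0};

    auto pooled = masked_mean_pooling(hidden_states, attention_mask);
    std::cout << pooled[0] << " " << pooled[1] << "\n";  // prints: 2 3
}
```

With the pre-patch behavior the same input would yield `{13/3, 15/3}` instead of `{2, 3}`, since the padding row is summed in and the divisor is 3 rather than 2; the batch length therefore changed the embedding of identical text.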