Fix mean pooling function

ozanarmagan 2023-12-19 14:51:46 +03:00
parent e8a0524536
commit c46d142b66
2 changed files with 18 additions and 6 deletions


@@ -45,7 +45,7 @@ class TextEmbedder {
     std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
     std::unique_ptr<RemoteEmbedder> remote_embedder_;
     std::string vocab_file_name;
-    static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
+    static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input, const std::vector<int64_t>& attention_mask);
     std::string output_tensor_name;
     size_t num_dim;
     std::mutex mutex_;


@@ -91,16 +91,28 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims)
 }
-std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs) {
+std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs, const std::vector<int64_t>& attention_mask) {
     std::vector<float> pooled_output;
     for (int i = 0; i < inputs[0].size(); i++) {
         float sum = 0;
         for (int j = 0; j < inputs.size(); j++) {
-            sum += inputs[j][i];
+            sum += inputs[j][i] * attention_mask[j];
         }
-        pooled_output.push_back(sum / inputs.size());
+        pooled_output.push_back(sum);
     }
+    // get sum of attention mask
+    float sum_attention_mask = 0;
+    for(auto& val : attention_mask) {
+        sum_attention_mask += val;
+    }
+    // divide by sum of attention mask
+    for(auto& val : pooled_output) {
+        val /= sum_attention_mask;
+    }
     return pooled_output;
 }
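For reference, here is a minimal standalone sketch of the pooling logic this hunk introduces, with a toy main() (the sample embeddings and mask values are illustrative, not taken from the Typesense codebase). Multiplying by the mask zeroes out padding positions during the per-dimension sum, and the divisor becomes the count of attended tokens rather than the padded sequence length, so padding embeddings no longer dilute the average.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> mean_pooling(const std::vector<std::vector<float>>& inputs,
                                const std::vector<int64_t>& attention_mask) {
    std::vector<float> pooled_output;
    // Sum each embedding dimension across tokens; the mask zeroes out
    // the contribution of padding positions.
    for (size_t i = 0; i < inputs[0].size(); i++) {
        float sum = 0;
        for (size_t j = 0; j < inputs.size(); j++) {
            sum += inputs[j][i] * attention_mask[j];
        }
        pooled_output.push_back(sum);
    }
    // Normalize by the number of attended tokens (sum of the mask),
    // not by the padded sequence length.
    float sum_attention_mask = 0;
    for (const auto& val : attention_mask) {
        sum_attention_mask += val;
    }
    for (auto& val : pooled_output) {
        val /= sum_attention_mask;
    }
    return pooled_output;
}

int main() {
    // Two real tokens plus one padding token carrying garbage values.
    std::vector<std::vector<float>> token_embeddings = {
        {1.0f, 2.0f}, {3.0f, 4.0f}, {9.9f, 9.9f}};
    std::vector<int64_t> mask = {1, 1, 0};
    auto pooled = mean_pooling(token_embeddings, mask);
    // Prints "2 3": only the two attended tokens are averaged. The old
    // code would print "4.63333 5.3" (garbage included, divided by 3).
    std::cout << pooled[0] << " " << pooled[1] << "\n";
}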
@@ -171,7 +183,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote
             }
             output.push_back(temp);
         }
-        auto pooled_output = mean_pooling(output);
+        auto pooled_output = mean_pooling(output, encoded_input.attention_mask);
         return embedding_res_t(pooled_output);
     }
 }
@@ -277,7 +289,7 @@ std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::st
                 output.push_back(output_row);
             }
             if(tokenizer_->get_tokenizer_type() != TokenizerType::clip) {
-                outputs.push_back(embedding_res_t(mean_pooling(output)));
+                outputs.push_back(embedding_res_t(mean_pooling(output, encoded_inputs.attention_mask[i])));
             }
         }
     }
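One detail worth noting in the batch path: after a batch is padded to a common length, every sequence keeps its own mask row, which is why the call above indexes encoded_inputs.attention_mask[i] rather than passing a single mask. A small self-contained sketch of that per-sequence normalization (shapes and values assumed for illustration, not taken from the Typesense codebase):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Batch of 2 sequences padded to length 3, embedding dim 1.
    std::vector<std::vector<std::vector<float>>> batch = {
        {{2.0f}, {4.0f}, {6.0f}},   // 3 real tokens
        {{8.0f}, {0.5f}, {0.5f}}};  // 1 real token + 2 padding tokens
    std::vector<std::vector<int64_t>> attention_mask = {{1, 1, 1}, {1, 0, 0}};

    for (size_t i = 0; i < batch.size(); i++) {
        float sum = 0, mask_sum = 0;
        for (size_t j = 0; j < batch[i].size(); j++) {
            sum += batch[i][j][0] * attention_mask[i][j];  // mask row i
            mask_sum += attention_mask[i][j];
        }
        // Prints 4, then 8: each sequence is averaged over its own
        // attended tokens only, regardless of how much padding it has.
        std::cout << sum / mask_sum << "\n";
    }
}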