mirror of
https://github.com/typesense/typesense.git
synced 2025-05-18 20:52:50 +08:00
Fix mean pooling function
This commit is contained in:
parent
e8a0524536
commit
c46d142b66
@ -45,7 +45,7 @@ class TextEmbedder {
|
||||
std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
|
||||
std::unique_ptr<RemoteEmbedder> remote_embedder_;
|
||||
std::string vocab_file_name;
|
||||
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
|
||||
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input, const std::vector<int64_t>& attention_mask);
|
||||
std::string output_tensor_name;
|
||||
size_t num_dim;
|
||||
std::mutex mutex_;
|
||||
|
@ -91,16 +91,28 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims)
|
||||
}
|
||||
|
||||
|
||||
std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs) {
|
||||
std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs, const std::vector<int64_t>& attention_mask) {
|
||||
|
||||
std::vector<float> pooled_output;
|
||||
for (int i = 0; i < inputs[0].size(); i++) {
|
||||
float sum = 0;
|
||||
for (int j = 0; j < inputs.size(); j++) {
|
||||
sum += inputs[j][i];
|
||||
sum += inputs[j][i] * attention_mask[j];
|
||||
}
|
||||
pooled_output.push_back(sum / inputs.size());
|
||||
pooled_output.push_back(sum);
|
||||
}
|
||||
|
||||
// get sum of attention mask
|
||||
float sum_attention_mask = 0;
|
||||
for(auto& val : attention_mask) {
|
||||
sum_attention_mask += val;
|
||||
}
|
||||
|
||||
// divide by sum of attention mask
|
||||
for(auto& val : pooled_output) {
|
||||
val /= sum_attention_mask;
|
||||
}
|
||||
|
||||
return pooled_output;
|
||||
}
|
||||
|
||||
@ -171,7 +183,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote
|
||||
}
|
||||
output.push_back(temp);
|
||||
}
|
||||
auto pooled_output = mean_pooling(output);
|
||||
auto pooled_output = mean_pooling(output, encoded_input.attention_mask);
|
||||
return embedding_res_t(pooled_output);
|
||||
}
|
||||
}
|
||||
@ -277,7 +289,7 @@ std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::st
|
||||
output.push_back(output_row);
|
||||
}
|
||||
if(tokenizer_->get_tokenizer_type() != TokenizerType::clip) {
|
||||
outputs.push_back(embedding_res_t(mean_pooling(output)));
|
||||
outputs.push_back(embedding_res_t(mean_pooling(output, encoded_inputs.attention_mask[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user