Merge pull request #1039 from ozanarmagan/v0.25-join

CUDA support for auto embedding
Kishore Nallan 2023-06-04 15:46:10 +05:30 committed by GitHub
commit 0612792ebc
8 changed files with 256 additions and 30 deletions

BUILD

@@ -38,6 +38,8 @@ cc_library(
deps = [
":headers",
":onnxruntime_lib",
"@sentencepiece",
"@sentencepiece//:sentencepiece_headers",
"@com_github_brpc_braft//:braft",
"@com_github_brpc_brpc//:brpc",
"@com_github_google_glog//:glog",
@@ -53,8 +55,6 @@ cc_library(
"@s2geometry",
"@hnsw",
# "@zip",
"@sentencepiece",
"@sentencepiece//:sentencepiece_headers"
],
)
@@ -194,9 +194,7 @@ mkdir -p $INSTALLDIR/lib
mkdir -p $INSTALLDIR/lib/_deps
mkdir -p $INSTALLDIR/lib/_deps/onnx-build
mkdir -p $INSTALLDIR/lib/_deps/re2-build
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build
mkdir -p $INSTALLDIR/lib/_deps/protobuf-build
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
@@ -207,21 +205,22 @@ mkdir -p $INSTALLDIR/lib/_deps/google_nsync-build
cp $BUILD_TMPDIR/_deps/onnx-build/libonnx.a $INSTALLDIR/lib/_deps/onnx-build
cp $BUILD_TMPDIR/_deps/onnx-build/libonnx_proto.a $INSTALLDIR/lib/_deps/onnx-build
cp $BUILD_TMPDIR/_deps/re2-build/libre2.a $INSTALLDIR/lib/_deps/re2-build
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_base.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_city.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/. $INSTALLDIR/lib/_deps/abseil_cpp-build -r
cp $BUILD_TMPDIR/_deps/google_nsync-build/libnsync_cpp.a $INSTALLDIR/lib/_deps/google_nsync-build
cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/libcpuinfo.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
"""
cmake(
name = "onnxruntime",
lib_source = "@onnx_runtime//:all_srcs",
cache_entries = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
__POSTFIX_WITH_CUDA = __POSTFIX + """
cp $BUILD_TMPDIR/libonnxruntime_providers_shared.so $EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin
cp $BUILD_TMPDIR/libonnxruntime_providers_cuda.so $EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin
"""
load("@cuda_home_repo//:cuda_home.bzl", "CUDA_HOME")
load("@cuda_home_repo//:cudnn_home.bzl", "CUDNN_HOME")
__ONNXRUNTIME_WITHOUT_CUDA = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
'onnxruntime_GENERATE_TEST_REPORTS':'ON',
'onnxruntime_USE_MIMALLOC':'OFF',
'onnxruntime_ENABLE_PYTHON':'OFF',
@@ -274,11 +273,86 @@ cmake(
'onnxruntime_USE_CANN':'OFF', 'CMAKE_TLS_VERIFY':'ON', 'FETCHCONTENT_QUIET':'OFF',
'onnxruntime_PYBIND_EXPORT_OPSCHEMA':'OFF', 'onnxruntime_ENABLE_MEMLEAK_CHECKER':'OFF',
'CMAKE_BUILD_TYPE':'Release',
},
}
__ONNXRUNTIME_WITH_CUDA = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
'onnxruntime_GENERATE_TEST_REPORTS':'ON',
'onnxruntime_USE_MIMALLOC':'OFF',
'onnxruntime_ENABLE_PYTHON':'OFF',
'onnxruntime_BUILD_CSHARP':'OFF',
'onnxruntime_BUILD_JAVA':'OFF',
'onnxruntime_BUILD_NODEJS':'OFF',
'onnxruntime_BUILD_OBJC':'OFF',
'onnxruntime_BUILD_SHARED_LIB':'OFF',
'onnxruntime_BUILD_APPLE_FRAMEWORK':'OFF',
'onnxruntime_USE_DNNL':'OFF',
'onnxruntime_USE_NNAPI_BUILTIN':'OFF',
'onnxruntime_USE_RKNPU':'OFF',
'onnxruntime_USE_LLVM':'OFF',
'onnxruntime_ENABLE_MICROSOFT_INTERNAL':'OFF',
'onnxruntime_USE_VITISAI':'OFF',
'onnxruntime_USE_TENSORRT':'OFF',
'onnxruntime_SKIP_AND_PERFORM_FILTERED_TENSORRT_TESTS':'ON',
'onnxruntime_USE_TENSORRT_BUILTIN_PARSER':'OFF',
'onnxruntime_TENSORRT_PLACEHOLDER_BUILDER':'OFF', 'onnxruntime_USE_TVM':'OFF',
'onnxruntime_TVM_CUDA_RUNTIME':'OFF', 'onnxruntime_TVM_USE_HASH':'OFF',
'onnxruntime_USE_MIGRAPHX':'OFF', 'onnxruntime_CROSS_COMPILING':'OFF',
'onnxruntime_DISABLE_CONTRIB_OPS':'OFF', 'onnxruntime_DISABLE_ML_OPS':'OFF',
'onnxruntime_DISABLE_RTTI':'OFF', 'onnxruntime_DISABLE_EXCEPTIONS':'OFF',
'onnxruntime_MINIMAL_BUILD':'OFF', 'onnxruntime_EXTENDED_MINIMAL_BUILD':'OFF',
'onnxruntime_MINIMAL_BUILD_CUSTOM_OPS':'OFF', 'onnxruntime_REDUCED_OPS_BUILD':'OFF',
'onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS':'OFF', 'onnxruntime_USE_DML':'OFF',
'onnxruntime_USE_WINML':'OFF', 'onnxruntime_BUILD_MS_EXPERIMENTAL_OPS':'OFF',
'onnxruntime_USE_TELEMETRY':'OFF', 'onnxruntime_ENABLE_LTO':'OFF',
'onnxruntime_USE_ACL':'OFF', 'onnxruntime_USE_ACL_1902':'OFF',
'onnxruntime_USE_ACL_1905':'OFF', 'onnxruntime_USE_ACL_1908':'OFF',
'onnxruntime_USE_ACL_2002':'OFF', 'onnxruntime_USE_ARMNN':'OFF',
'onnxruntime_ARMNN_RELU_USE_CPU':'ON', 'onnxruntime_ARMNN_BN_USE_CPU':'ON',
'onnxruntime_ENABLE_NVTX_PROFILE':'OFF', 'onnxruntime_ENABLE_TRAINING':'OFF',
'onnxruntime_ENABLE_TRAINING_OPS':'OFF', 'onnxruntime_ENABLE_TRAINING_APIS':'OFF',
'onnxruntime_ENABLE_CPU_FP16_OPS':'OFF', 'onnxruntime_USE_NCCL':'OFF',
'onnxruntime_BUILD_BENCHMARKS':'OFF', 'onnxruntime_USE_ROCM':'OFF',
'Onnxruntime_GCOV_COVERAGE':'OFF', 'onnxruntime_USE_MPI':'ON',
'onnxruntime_ENABLE_MEMORY_PROFILE':'OFF',
'onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO':'OFF',
'onnxruntime_BUILD_WEBASSEMBLY':'OFF', 'onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING':'ON',
'onnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING':'ON',
'onnxruntime_ENABLE_WEBASSEMBLY_THREADS':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_PROFILING':'OFF',
'onnxruntime_ENABLE_EAGER_MODE':'OFF', 'onnxruntime_ENABLE_LAZY_TENSOR':'OFF',
'onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS':'OFF', 'onnxruntime_ENABLE_CUDA_PROFILING':'OFF',
'onnxruntime_ENABLE_ROCM_PROFILING':'OFF', 'onnxruntime_USE_XNNPACK':'OFF',
'onnxruntime_USE_CANN':'OFF', 'CMAKE_TLS_VERIFY':'ON', 'FETCHCONTENT_QUIET':'OFF',
'onnxruntime_PYBIND_EXPORT_OPSCHEMA':'OFF', 'onnxruntime_ENABLE_MEMLEAK_CHECKER':'OFF',
'CMAKE_BUILD_TYPE':'Release', 'onnxruntime_USE_CUDA':'ON', 'onnxruntime_USE_CUDNN':'ON',
'onnxruntime_CUDA_HOME': CUDA_HOME,
'onnxruntime_CUDNN_HOME': CUDNN_HOME,
'CMAKE_CUDA_COMPILER': CUDA_HOME + "/bin/nvcc"
}
config_setting(
name = "with_cuda",
define_values = { "use_cuda": "on" }
)
cmake(
name = "onnxruntime",
lib_source = "@onnx_runtime//:all_srcs",
cache_entries = select({
":with_cuda": __ONNXRUNTIME_WITH_CUDA,
"//conditions:default": __ONNXRUNTIME_WITHOUT_CUDA
}),
working_directory="cmake",
build_args= [
"--config Release",
"-j6"
"-j3"
],
tags=["requires-network","no-sandbox"],
features=["-default_compile_flags","-fno-canonical-system-headers", "-Wno-builtin-macro-redefined"],
@@ -304,7 +378,10 @@ cmake(
"_deps/pytorch_cpuinfo-build/libcpuinfo.a",
"_deps/pytorch_cpuinfo-build/deps/clog/libclog.a",
],
postfix_script=__POSTFIX
postfix_script= select({
":with_cuda": __POSTFIX_WITH_CUDA,
"//conditions:default": __POSTFIX
}),
)
cc_library(


@@ -1,6 +1,10 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
load("//bazel:onnxruntime_cuda_defs.bzl", "cuda_home_repository")
cuda_home_repository(name = "cuda_home_repo")
git_repository(
name = "com_grail_bazel_compdb",
commit = "58672f5eecd70a2d3ed50016a3abf907701404e0",
@@ -40,7 +44,7 @@ git_repository(
new_git_repository(
name="onnx_runtime",
branch= "rel-1.14.0",
branch= "rel-1.14.1",
build_file = "//bazel:onnxruntime.BUILD",
remote= "https://github.com/microsoft/onnxruntime",
patches=["//bazel:onnx.patch"],
@@ -305,4 +309,4 @@ http_archive(
sha256 = "ecc406914edf335f0b7fc084ebe6c460c4d6d5175bfdd6688c1c78d9146b8858",
strip_prefix = "elfutils-0.182",
urls = ["https://sourceware.org/elfutils/ftp/0.182/elfutils-0.182.tar.bz2"],
)
)


@@ -0,0 +1,9 @@
def cuda_impl(repository_ctx):
repository_ctx.file("cuda_home.bzl", "CUDA_HOME = \"%s\"" % repository_ctx.os.environ.get("CUDA_HOME", ""))
repository_ctx.file("cudnn_home.bzl", "CUDNN_HOME = \"%s\"" % repository_ctx.os.environ.get("CUDNN_HOME", ""))
repository_ctx.file("BUILD", "exports_files([\"cuda_home.bzl\", \"cudnn_home.bzl\"])")
cuda_home_repository = repository_rule(
implementation=cuda_impl,
environ = ["CUDA_HOME", "CUDNN_HOME"],
)


@@ -30,9 +30,9 @@ cmake(
install = False,
cache_entries = {
'SPM_USE_BUILTIN_PROTOBUF': 'OFF',
'Protobuf_LIBRARY': '$EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin/external/com_google_protobuf/libprotobuf.a',
'Protobuf_LITE_LIBRARY': '$EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin/external/com_google_protobuf/libprotobuf-lite.a',
'Protobuf_PROTOC_EXECUTABLE': '$EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin/external/com_google_protobuf/protoc',
'Protobuf_LIBRARY': '$INSTALLDIR/../../com_google_protobuf/libprotobuf.a',
'Protobuf_LITE_LIBRARY': '$INSTALLDIR/../../com_google_protobuf/libprotobuf-lite.a',
'Protobuf_PROTOC_EXECUTABLE': '$INSTALLDIR/../../com_google_protobuf/protoc',
'Protobuf_INCLUDE_DIR': '$EXT_BUILD_ROOT/external/com_google_protobuf/src',
'CMAKE_POLICY_DEFAULT_CMP0111':'OLD'
},


@@ -27,6 +27,7 @@ class TextEmbedder {
std::unique_ptr<Ort::Session> session_;
Ort::Env env_;
encoded_input_t Encode(const std::string& text);
batch_encoded_input_t batch_encode(const std::vector<std::string>& inputs);
std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
std::unique_ptr<RemoteEmbedder> remote_embedder_;
std::string vocab_file_name;


@@ -18,6 +18,11 @@ struct encoded_input_t {
std::vector<int64_t> attention_mask;
};
struct batch_encoded_input_t {
std::vector<std::vector<int64_t>> input_ids;
std::vector<std::vector<int64_t>> token_type_ids;
std::vector<std::vector<int64_t>> attention_mask;
};
// Create a base class for all tokenizers to inherit from


@@ -6474,7 +6474,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema) {
for(const auto& field : embedding_fields) {
std::vector<std::string> text_to_embed;
std::vector<std::pair<nlohmann::json*, std::string>> texts_to_embed;
auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
for(auto& document : documents) {
std::string text = indexing_prefix;
@@ -6489,7 +6489,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
}
}
}
text_to_embed.push_back(text);
texts_to_embed.push_back(std::make_pair(document, text));
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
@@ -6497,8 +6497,21 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
if(!embedder_op.ok()) {
return Option<bool>(400, embedder_op.error());
}
// sort texts by length
std::sort(texts_to_embed.begin(), texts_to_embed.end(),
[](const std::pair<nlohmann::json*, std::string>& a,
const std::pair<nlohmann::json*, std::string>& b) {
return a.second.size() < b.second.size();
});
auto embedding_op = embedder_op.get()->batch_embed(text_to_embed);
// get vector of texts
std::vector<std::string> texts;
for(const auto& text_to_embed : texts_to_embed) {
texts.push_back(text_to_embed.second);
}
auto embedding_op = embedder_op.get()->batch_embed(texts);
if(!embedding_op.ok()) {
return Option<bool>(400, embedding_op.error());
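The sort above pairs with the fixed batch size of 8 used downstream in TextEmbedder::batch_embed: grouping texts of similar length into the same batch keeps each batch's padded length close to the real lengths, so less compute is spent on padding. A rough, hypothetical C++ sketch of that effect (the lengths, the padded_tokens helper and its output are illustrative, not part of this commit):

#include <algorithm>
#include <iostream>
#include <vector>

// illustrative helper: total padded token count when fixed-size batches are
// each padded to the longest sequence they contain
static size_t padded_tokens(std::vector<size_t> lens, size_t batch_size, bool sort_first) {
    if(sort_first) {
        std::sort(lens.begin(), lens.end());
    }
    size_t total = 0;
    for(size_t i = 0; i < lens.size(); i += batch_size) {
        size_t end = std::min(i + batch_size, lens.size());
        size_t max_len = *std::max_element(lens.begin() + i, lens.begin() + end);
        total += max_len * (end - i);
    }
    return total;
}

int main() {
    // hypothetical token lengths of 16 documents
    std::vector<size_t> lens = {5, 120, 7, 98, 6, 110, 9, 101, 4, 95, 8, 130, 3, 99, 10, 105};
    std::cout << "padded tokens, unsorted: " << padded_tokens(lens, 8, false) << "\n";
    std::cout << "padded tokens, sorted:   " << padded_tokens(lens, 8, true) << "\n";
}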
@@ -6506,7 +6519,9 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
auto embeddings = embedding_op.get();
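// note: batch_embed returns one embedding per input text, in the same order as texts;
// texts was built from the sorted texts_to_embed, so index i lines up across both, and
// the document pointer stored in the pair maps each embedding back to its original document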
for(size_t i = 0; i < embeddings.size(); i++) {
(*documents[i])[field.name] = embeddings[i];
auto& embedding = embeddings[i];
auto& document = texts_to_embed[i].first;
(*document)[field.name] = embedding;
}
}
return Option<bool>(true);


@@ -9,6 +9,14 @@
TextEmbedder::TextEmbedder(const std::string& model_name) {
// create environment
Ort::SessionOptions session_options;
auto providers = Ort::GetAvailableProviders();
for(auto& provider : providers) {
if(provider == "CUDAExecutionProvider") {
LOG(INFO) << "Using CUDAExecutionProvider";
OrtCUDAProviderOptions cuda_options;
session_options.AppendExecutionProvider_CUDA(cuda_options);
}
}
std::string abs_path = TextEmbedderManager::get_absolute_model_path(model_name);
LOG(INFO) << "Loading model from disk: " << abs_path;
session_ = std::make_unique<Ort::Session>(env_, abs_path.c_str(), session_options);
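For context, registering the CUDA execution provider only when onnxruntime reports it keeps the default CPU provider as the fallback on machines without a GPU or without a CUDA-enabled onnxruntime build. A minimal, self-contained sketch of the same pattern, assuming an illustrative model path and device id that are not taken from this commit:

#include <onnxruntime_cxx_api.h>

int main() {
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "embedder");
    Ort::SessionOptions session_options;
    // register CUDA only if this onnxruntime build exposes it; otherwise the
    // session simply uses the CPU execution provider
    for(const auto& provider : Ort::GetAvailableProviders()) {
        if(provider == "CUDAExecutionProvider") {
            OrtCUDAProviderOptions cuda_options;
            cuda_options.device_id = 0;  // illustrative: first GPU
            session_options.AppendExecutionProvider_CUDA(cuda_options);
        }
    }
    Ort::Session session(env, "model.onnx", session_options);  // placeholder path
    return 0;
}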
@@ -128,9 +136,83 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
std::vector<std::vector<float>> outputs;
if(!is_remote()) {
// for now only openai is supported for batch embedding
for(const auto& input : inputs) {
outputs.push_back(Embed(input).get());
for(int i = 0; i < inputs.size(); i += 8) {
auto input_batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + 8, static_cast<int>(inputs.size())));
auto encoded_inputs = batch_encode(input_batch);
// create input tensor object from data values
Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
std::vector<Ort::Value> input_tensors;
std::vector<std::vector<int64_t>> input_shapes;
std::vector<const char*> input_node_names = {"input_ids", "attention_mask"};
// If model is DistilBERT or sentencepiece, it has 2 inputs, else it has 3 inputs
if(session_->GetInputCount() == 3) {
input_node_names.push_back("token_type_ids");
}
input_shapes.push_back({static_cast<int64_t>(encoded_inputs.input_ids.size()), static_cast<int64_t>(encoded_inputs.input_ids[0].size())});
input_shapes.push_back({static_cast<int64_t>(encoded_inputs.attention_mask.size()), static_cast<int64_t>(encoded_inputs.attention_mask[0].size())});
if(session_->GetInputCount() == 3) {
input_shapes.push_back({static_cast<int64_t>(encoded_inputs.token_type_ids.size()), static_cast<int64_t>(encoded_inputs.token_type_ids[0].size())});
}
std::vector<int64_t> input_ids_flatten;
std::vector<int64_t> attention_mask_flatten;
std::vector<int64_t> token_type_ids_flatten;
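// Ort::Value::CreateTensor expects a contiguous row-major buffer, so each
// [batch_size, seq_len] matrix is flattened below with element (i, j) landing at
// index i * seq_len + j; the 2-D shapes recorded in input_shapes tell onnxruntime
// how to reinterpret the flat data.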
for (int i = 0; i < encoded_inputs.input_ids.size(); i++) {
for (int j = 0; j < encoded_inputs.input_ids[i].size(); j++) {
input_ids_flatten.push_back(encoded_inputs.input_ids[i][j]);
}
}
for (int i = 0; i < encoded_inputs.attention_mask.size(); i++) {
for (int j = 0; j < encoded_inputs.attention_mask[i].size(); j++) {
attention_mask_flatten.push_back(encoded_inputs.attention_mask[i][j]);
}
}
if(session_->GetInputCount() == 3) {
for (int i = 0; i < encoded_inputs.token_type_ids.size(); i++) {
for (int j = 0; j < encoded_inputs.token_type_ids[i].size(); j++) {
token_type_ids_flatten.push_back(encoded_inputs.token_type_ids[i][j]);
}
}
}
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, input_ids_flatten.data(), input_ids_flatten.size(), input_shapes[0].data(), input_shapes[0].size()));
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, attention_mask_flatten.data(), attention_mask_flatten.size(), input_shapes[1].data(), input_shapes[1].size()));
if(session_->GetInputCount() == 3) {
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, token_type_ids_flatten.data(), token_type_ids_flatten.size(), input_shapes[2].data(), input_shapes[2].size()));
}
//LOG(INFO) << "Running model";
// create output tensor object
std::vector<const char*> output_node_names = {output_tensor_name.c_str()};
// if seq length is 0, return empty vector
if(input_shapes[0][1] == 0) {
for(int i = 0; i < input_batch.size(); i++) {
outputs.push_back(std::vector<float>());
}
continue;
}
auto output_tensor = session_->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_tensors.size(), output_node_names.data(), output_node_names.size());
float* data = output_tensor[0].GetTensorMutableData<float>();
// output tensor shape is [batch_size, sequence_length, hidden_size]
auto shape = output_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
for (int i = 0; i < shape[0]; i++) {
std::vector<std::vector<float>> output;
for (int j = 0; j < shape[1]; j++) {
std::vector<float> output_row;
for (int k = 0; k < shape[2]; k++) {
output_row.push_back(data[i * shape[1] * shape[2] + j * shape[2] + k]);
}
output.push_back(output_row);
}
outputs.push_back(mean_pooling(output));
}
}
} else {
auto embed_op = remote_embedder_->batch_embed(inputs);
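mean_pooling itself is not part of this diff. As a hedged sketch, pooling one example's [seq_len, hidden_size] token embeddings into a single sentence embedding is typically an unweighted average over the sequence dimension; the helper below is an assumption about that behaviour, not the repository's exact implementation:

#include <vector>

// hedged sketch: average the token embeddings over the sequence dimension
std::vector<float> mean_pooling_sketch(const std::vector<std::vector<float>>& token_embeddings) {
    std::vector<float> pooled;
    if(token_embeddings.empty()) {
        return pooled;
    }
    pooled.assign(token_embeddings[0].size(), 0.0f);
    for(const auto& token : token_embeddings) {
        for(size_t d = 0; d < token.size(); d++) {
            pooled[d] += token[d];
        }
    }
    for(auto& v : pooled) {
        v /= static_cast<float>(token_embeddings.size());
    }
    return pooled;
}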
@@ -264,4 +346,37 @@ Option<bool> TextEmbedder::validate_local_or_public_model(const nlohmann::json&
}
return Option<bool>(true);
}
batch_encoded_input_t TextEmbedder::batch_encode(const std::vector<std::string>& inputs) {
batch_encoded_input_t encoded_inputs;
for(auto& input : inputs) {
auto encoded_input = tokenizer_->Encode(input);
encoded_inputs.input_ids.push_back(encoded_input.input_ids);
encoded_inputs.attention_mask.push_back(encoded_input.attention_mask);
encoded_inputs.token_type_ids.push_back(encoded_input.token_type_ids);
}
// Pad inputs
size_t max_input_len = 0;
for(auto& input_ids : encoded_inputs.input_ids) {
if(input_ids.size() > max_input_len) {
max_input_len = input_ids.size();
}
}
for(auto& input_ids : encoded_inputs.input_ids) {
input_ids.resize(max_input_len, 0);
}
for(auto& attention_mask : encoded_inputs.attention_mask) {
attention_mask.resize(max_input_len, 0);
}
for(auto& token_type_ids : encoded_inputs.token_type_ids) {
token_type_ids.resize(max_input_len, 0);
}
return encoded_inputs;
}
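A self-contained illustration of the padding contract batch_encode establishes (the token ids and two-row batch below are made up, and treating 0 as the pad id is an assumption about the tokenizer's vocabulary):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // two encoded inputs of different lengths
    std::vector<std::vector<int64_t>> input_ids      = {{101, 2023, 102}, {101, 102}};
    std::vector<std::vector<int64_t>> attention_mask = {{1, 1, 1}, {1, 1}};

    size_t max_len = 0;
    for(const auto& row : input_ids) {
        max_len = std::max(max_len, row.size());
    }
    // every row is resized to the longest row; padded positions get id 0 and
    // attention_mask 0, so the model ignores them
    for(auto& row : input_ids) row.resize(max_len, 0);
    for(auto& row : attention_mask) row.resize(max_len, 0);

    assert(input_ids[1].size() == 3);
    assert(attention_mask[1][2] == 0);
    return 0;
}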