diff --git a/BUILD b/BUILD
index a2ca8ffa..f9208da6 100644
--- a/BUILD
+++ b/BUILD
@@ -38,6 +38,8 @@ cc_library(
     deps = [
         ":headers",
         ":onnxruntime_lib",
+        "@sentencepiece",
+        "@sentencepiece//:sentencepiece_headers",
         "@com_github_brpc_braft//:braft",
         "@com_github_brpc_brpc//:brpc",
         "@com_github_google_glog//:glog",
@@ -53,8 +55,6 @@ cc_library(
         "@s2geometry",
         "@hnsw",
         # "@zip",
-        "@sentencepiece",
-        "@sentencepiece//:sentencepiece_headers"
     ],
 )
 
@@ -194,9 +194,7 @@ mkdir -p $INSTALLDIR/lib
 mkdir -p $INSTALLDIR/lib/_deps
 mkdir -p $INSTALLDIR/lib/_deps/onnx-build
 mkdir -p $INSTALLDIR/lib/_deps/re2-build
-mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build
-mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl
-mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
+mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build
 mkdir -p $INSTALLDIR/lib/_deps/protobuf-build
 mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
 mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
@@ -207,21 +205,22 @@ mkdir -p $INSTALLDIR/lib/_deps/google_nsync-build
 cp $BUILD_TMPDIR/_deps/onnx-build/libonnx.a $INSTALLDIR/lib/_deps/onnx-build
 cp $BUILD_TMPDIR/_deps/onnx-build/libonnx_proto.a $INSTALLDIR/lib/_deps/onnx-build
 cp $BUILD_TMPDIR/_deps/re2-build/libre2.a $INSTALLDIR/lib/_deps/re2-build
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_base.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_city.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
+cp $BUILD_TMPDIR/_deps/abseil_cpp-build/. $INSTALLDIR/lib/_deps/abseil_cpp-build -r
 cp $BUILD_TMPDIR/_deps/google_nsync-build/libnsync_cpp.a $INSTALLDIR/lib/_deps/google_nsync-build
 cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
 cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/libcpuinfo.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
 """
-cmake(
-    name = "onnxruntime",
-    lib_source = "@onnx_runtime//:all_srcs",
-    cache_entries = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
+__POSTFIX_WITH_CUDA = __POSTFIX + """
+cp $BUILD_TMPDIR/libonnxruntime_providers_shared.so $EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin
+cp $BUILD_TMPDIR/libonnxruntime_providers_cuda.so $EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin
+"""
+
+
+load("@cuda_home_repo//:cuda_home.bzl", "CUDA_HOME")
+load("@cuda_home_repo//:cudnn_home.bzl", "CUDNN_HOME")
+
+__ONNXRUNTIME_WITHOUT_CUDA = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
 'onnxruntime_GENERATE_TEST_REPORTS':'ON',
 'onnxruntime_USE_MIMALLOC':'OFF',
 'onnxruntime_ENABLE_PYTHON':'OFF',
@@ -274,11 +273,86 @@ cmake(
 'onnxruntime_USE_CANN':'OFF',
 'CMAKE_TLS_VERIFY':'ON',
 'FETCHCONTENT_QUIET':'OFF',
 'onnxruntime_PYBIND_EXPORT_OPSCHEMA':'OFF',
 'onnxruntime_ENABLE_MEMLEAK_CHECKER':'OFF',
 'CMAKE_BUILD_TYPE':'Release',
-},
+}
+
+
+__ONNXRUNTIME_WITH_CUDA = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
+'onnxruntime_GENERATE_TEST_REPORTS':'ON',
+'onnxruntime_USE_MIMALLOC':'OFF',
+'onnxruntime_ENABLE_PYTHON':'OFF',
+'onnxruntime_BUILD_CSHARP':'OFF',
+'onnxruntime_BUILD_JAVA':'OFF',
+'onnxruntime_BUILD_NODEJS':'OFF',
+'onnxruntime_BUILD_OBJC':'OFF',
+'onnxruntime_BUILD_SHARED_LIB':'OFF',
+'onnxruntime_BUILD_APPLE_FRAMEWORK':'OFF',
+'onnxruntime_USE_DNNL':'OFF',
+'onnxruntime_USE_NNAPI_BUILTIN':'OFF',
+'onnxruntime_USE_RKNPU':'OFF',
+'onnxruntime_USE_LLVM':'OFF',
+'onnxruntime_ENABLE_MICROSOFT_INTERNAL':'OFF',
+'onnxruntime_USE_VITISAI':'OFF',
+'onnxruntime_USE_TENSORRT':'OFF',
+'onnxruntime_SKIP_AND_PERFORM_FILTERED_TENSORRT_TESTS':'ON',
+'onnxruntime_USE_TENSORRT_BUILTIN_PARSER':'OFF',
+'onnxruntime_TENSORRT_PLACEHOLDER_BUILDER':'OFF', 'onnxruntime_USE_TVM':'OFF',
+'onnxruntime_TVM_CUDA_RUNTIME':'OFF', 'onnxruntime_TVM_USE_HASH':'OFF',
+'onnxruntime_USE_MIGRAPHX':'OFF', 'onnxruntime_CROSS_COMPILING':'OFF',
+'onnxruntime_DISABLE_CONTRIB_OPS':'OFF', 'onnxruntime_DISABLE_ML_OPS':'OFF',
+'onnxruntime_DISABLE_RTTI':'OFF', 'onnxruntime_DISABLE_EXCEPTIONS':'OFF',
+'onnxruntime_MINIMAL_BUILD':'OFF', 'onnxruntime_EXTENDED_MINIMAL_BUILD':'OFF',
+'onnxruntime_MINIMAL_BUILD_CUSTOM_OPS':'OFF', 'onnxruntime_REDUCED_OPS_BUILD':'OFF',
+'onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS':'OFF', 'onnxruntime_USE_DML':'OFF',
+'onnxruntime_USE_WINML':'OFF', 'onnxruntime_BUILD_MS_EXPERIMENTAL_OPS':'OFF',
+'onnxruntime_USE_TELEMETRY':'OFF', 'onnxruntime_ENABLE_LTO':'OFF',
+'onnxruntime_USE_ACL':'OFF', 'onnxruntime_USE_ACL_1902':'OFF',
+'onnxruntime_USE_ACL_1905':'OFF', 'onnxruntime_USE_ACL_1908':'OFF',
+'onnxruntime_USE_ACL_2002':'OFF', 'onnxruntime_USE_ARMNN':'OFF',
+'onnxruntime_ARMNN_RELU_USE_CPU':'ON', 'onnxruntime_ARMNN_BN_USE_CPU':'ON',
+'onnxruntime_ENABLE_NVTX_PROFILE':'OFF', 'onnxruntime_ENABLE_TRAINING':'OFF',
+'onnxruntime_ENABLE_TRAINING_OPS':'OFF', 'onnxruntime_ENABLE_TRAINING_APIS':'OFF',
+'onnxruntime_ENABLE_CPU_FP16_OPS':'OFF', 'onnxruntime_USE_NCCL':'OFF',
+'onnxruntime_BUILD_BENCHMARKS':'OFF', 'onnxruntime_USE_ROCM':'OFF',
+'Onnxruntime_GCOV_COVERAGE':'OFF', 'onnxruntime_USE_MPI':'ON',
+'onnxruntime_ENABLE_MEMORY_PROFILE':'OFF',
+'onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO':'OFF',
+'onnxruntime_BUILD_WEBASSEMBLY':'OFF', 'onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB':'OFF',
+'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING':'ON',
+'onnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING':'OFF',
+'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING':'ON',
+'onnxruntime_ENABLE_WEBASSEMBLY_THREADS':'OFF',
+'onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO':'OFF',
+'onnxruntime_ENABLE_WEBASSEMBLY_PROFILING':'OFF',
+'onnxruntime_ENABLE_EAGER_MODE':'OFF', 'onnxruntime_ENABLE_LAZY_TENSOR':'OFF',
+'onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS':'OFF', 'onnxruntime_ENABLE_CUDA_PROFILING':'OFF',
+'onnxruntime_ENABLE_ROCM_PROFILING':'OFF', 'onnxruntime_USE_XNNPACK':'OFF',
+'onnxruntime_USE_CANN':'OFF', 'CMAKE_TLS_VERIFY':'ON', 'FETCHCONTENT_QUIET':'OFF',
+'onnxruntime_PYBIND_EXPORT_OPSCHEMA':'OFF', 'onnxruntime_ENABLE_MEMLEAK_CHECKER':'OFF',
+'CMAKE_BUILD_TYPE':'Release', 'onnxruntime_USE_CUDA':'ON', 'onnxruntime_USE_CUDNN':'ON',
+'onnxruntime_CUDA_HOME': CUDA_HOME,
+'onnxruntime_CUDNN_HOME': CUDNN_HOME,
+'CMAKE_CUDA_COMPILER': CUDA_HOME + "/bin/nvcc"
+}
+
+
+config_setting(
+    name = "with_cuda",
+    define_values = { "use_cuda": "on" }
+)
+
+
+cmake(
+    name = "onnxruntime",
+    lib_source = "@onnx_runtime//:all_srcs",
+    cache_entries = select({
+        ":with_cuda": __ONNXRUNTIME_WITH_CUDA,
+        "//conditions:default": __ONNXRUNTIME_WITHOUT_CUDA
+    }),
     working_directory="cmake",
     build_args= [
         "--config Release",
-        "-j6"
+        "-j3"
     ],
     tags=["requires-network","no-sandbox"],
     features=["-default_compile_flags","-fno-canonical-system-headers", "-Wno-builtin-macro-redefined"],
@@ -304,7 +378,10 @@ cmake(
         "_deps/pytorch_cpuinfo-build/libcpuinfo.a",
         "_deps/pytorch_cpuinfo-build/deps/clog/libclog.a",
     ],
-    postfix_script=__POSTFIX
+    postfix_script= select({
+        ":with_cuda": __POSTFIX_WITH_CUDA,
+        "//conditions:default": __POSTFIX
+    }),
 )
 
 cc_library(
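
Note on the CUDA switch above: with_cuda is a plain config_setting on the "use_cuda" define, so the CUDA cache entries and the provider-copying postfix script are only selected when the build is invoked with --define use_cuda=on (for example, bazel build --define use_cuda=on <target>); the default remains the CPU-only configuration. Also worth flagging: __POSTFIX_WITH_CUDA copies the provider .so files into a hardcoded bazel-out/k8-fastbuild/bin path, which assumes the default fastbuild output directory on x86-64.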
diff --git a/WORKSPACE b/WORKSPACE
index 2948d081..02096e28 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,6 +1,10 @@
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
 load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
 
+load("//bazel:onnxruntime_cuda_defs.bzl", "cuda_home_repository")
+
+cuda_home_repository(name = "cuda_home_repo")
+
 git_repository(
     name = "com_grail_bazel_compdb",
     commit = "58672f5eecd70a2d3ed50016a3abf907701404e0",
@@ -40,7 +44,7 @@ git_repository(
 
 new_git_repository(
     name="onnx_runtime",
-    branch= "rel-1.14.0",
+    branch= "rel-1.14.1",
     build_file = "//bazel:onnxruntime.BUILD",
     remote= "https://github.com/microsoft/onnxruntime",
     patches=["//bazel:onnx.patch"],
@@ -305,4 +309,4 @@ http_archive(
     sha256 = "ecc406914edf335f0b7fc084ebe6c460c4d6d5175bfdd6688c1c78d9146b8858",
     strip_prefix = "elfutils-0.182",
     urls = ["https://sourceware.org/elfutils/ftp/0.182/elfutils-0.182.tar.bz2"],
-)
+)
\ No newline at end of file
diff --git a/bazel/onnxruntime_cuda_defs.bzl b/bazel/onnxruntime_cuda_defs.bzl
new file mode 100755
index 00000000..6e155368
--- /dev/null
+++ b/bazel/onnxruntime_cuda_defs.bzl
@@ -0,0 +1,9 @@
+def cuda_impl(repository_ctx):
+    repository_ctx.file("cuda_home.bzl", "CUDA_HOME = \"%s\"" % repository_ctx.os.environ.get("CUDA_HOME", ""))
+    repository_ctx.file("cudnn_home.bzl", "CUDNN_HOME = \"%s\"" % repository_ctx.os.environ.get("CUDNN_HOME", ""))
+    repository_ctx.file("BUILD", "exports_files([\"cuda_home.bzl\", \"cudnn_home.bzl\"])")
+
+cuda_home_repository = repository_rule(
+    implementation=cuda_impl,
+    environ = ["CUDA_HOME", "CUDNN_HOME"],
+)
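
The repository rule above snapshots CUDA_HOME and CUDNN_HOME from the environment at repository-fetch time and bakes them into generated .bzl files; listing them in environ tells Bazel to re-run the rule when either variable changes, and an unset variable is recorded as an empty string. Both variables should be exported before building with CUDA enabled.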
diff --git a/bazel/sentencepiece.BUILD b/bazel/sentencepiece.BUILD
index babe2013..6958221c 100644
--- a/bazel/sentencepiece.BUILD
+++ b/bazel/sentencepiece.BUILD
@@ -30,9 +30,9 @@ cmake(
     install = False,
     cache_entries = {
         'SPM_USE_BUILTIN_PROTOBUF': 'OFF',
-        'Protobuf_LIBRARY': '$EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin/external/com_google_protobuf/libprotobuf.a',
-        'Protobuf_LITE_LIBRARY': '$EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin/external/com_google_protobuf/libprotobuf-lite.a',
-        'Protobuf_PROTOC_EXECUTABLE': '$EXT_BUILD_ROOT/bazel-out/k8-fastbuild/bin/external/com_google_protobuf/protoc',
+        'Protobuf_LIBRARY': '$INSTALLDIR/../../com_google_protobuf/libprotobuf.a',
+        'Protobuf_LITE_LIBRARY': '$INSTALLDIR/../../com_google_protobuf/libprotobuf-lite.a',
+        'Protobuf_PROTOC_EXECUTABLE': '$INSTALLDIR/../../com_google_protobuf/protoc',
         'Protobuf_INCLUDE_DIR': '$EXT_BUILD_ROOT/external/com_google_protobuf/src',
         'CMAKE_POLICY_DEFAULT_CMP0111':'OLD'
     },
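
The protobuf paths here previously hardcoded the bazel-out/k8-fastbuild/bin output tree; resolving them relative to $INSTALLDIR (which rules_foreign_cc sets for this rule) presumably keeps them valid when the output configuration is not k8-fastbuild. The exact layout assumed by the ../../ hops depends on where rules_foreign_cc places the install directory.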
diff --git a/include/text_embedder.h b/include/text_embedder.h
index 38666a2e..4bae7746 100644
--- a/include/text_embedder.h
+++ b/include/text_embedder.h
@@ -27,6 +27,7 @@ class TextEmbedder {
         std::unique_ptr<Ort::Session> session_;
         Ort::Env env_;
         encoded_input_t Encode(const std::string& text);
+        batch_encoded_input_t batch_encode(const std::vector<std::string>& inputs);
         std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
         std::unique_ptr<RemoteEmbedder> remote_embedder_;
         std::string vocab_file_name;
diff --git a/include/text_embedder_tokenizer.h b/include/text_embedder_tokenizer.h
index bd41717a..a9b3a41b 100644
--- a/include/text_embedder_tokenizer.h
+++ b/include/text_embedder_tokenizer.h
@@ -18,6 +18,11 @@ struct encoded_input_t {
     std::vector<int64_t> attention_mask;
 };
 
+struct batch_encoded_input_t {
+    std::vector<std::vector<int64_t>> input_ids;
+    std::vector<std::vector<int64_t>> token_type_ids;
+    std::vector<std::vector<int64_t>> attention_mask;
+};
 
 // Create a base class for all tokenizers to inherit from
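
The padding contract these structs imply: batch_encode (implemented in src/text_embedder.cpp below) right-pads every row of input_ids, token_type_ids and attention_mask with zeros to the length of the longest sequence in the batch, so each field can be flattened into one dense [batch, seq] tensor. A minimal stand-alone sketch of that step (hypothetical helper name, not part of this patch):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Right-pad each row with zeros to the longest row's length, so the
    // batch can be laid out as a dense [batch, seq] tensor. Zero is also
    // the "ignore this position" value in the attention mask.
    void pad_rows(std::vector<std::vector<int64_t>>& rows) {
        size_t max_len = 0;
        for (const auto& row : rows) {
            max_len = std::max(max_len, row.size());
        }
        for (auto& row : rows) {
            row.resize(max_len, 0);
        }
    }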
diff --git a/src/index.cpp b/src/index.cpp
index 80c0ab9e..ff4bdd03 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -6474,7 +6474,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
                                        const tsl::htrie_map<char, field>& embedding_fields,
                                        const tsl::htrie_map<char, field>& search_schema) {
     for(const auto& field : embedding_fields) {
-        std::vector<std::string> text_to_embed;
+        std::vector<std::pair<nlohmann::json*, std::string>> texts_to_embed;
         auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
         for(auto& document : documents) {
             std::string text = indexing_prefix;
@@ -6489,7 +6489,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
                     }
                 }
             }
-            text_to_embed.push_back(text);
+            texts_to_embed.push_back(std::make_pair(document, text));
         }
         TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
         auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
@@ -6497,8 +6497,21 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
         if(!embedder_op.ok()) {
             return Option<bool>(400, embedder_op.error());
         }
+
+        // sort texts by length so each fixed-size batch packs similar lengths
+        std::sort(texts_to_embed.begin(), texts_to_embed.end(),
+                  [](const std::pair<nlohmann::json*, std::string>& a,
+                     const std::pair<nlohmann::json*, std::string>& b) {
+                      return a.second.size() < b.second.size();
+                  });
 
-        auto embedding_op = embedder_op.get()->batch_embed(text_to_embed);
+        // collect the texts, in sorted order
+        std::vector<std::string> texts;
+        for(const auto& text_to_embed : texts_to_embed) {
+            texts.push_back(text_to_embed.second);
+        }
+
+        auto embedding_op = embedder_op.get()->batch_embed(texts);
 
         if(!embedding_op.ok()) {
             return Option<bool>(400, embedding_op.error());
@@ -6506,7 +6519,9 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
         auto embeddings = embedding_op.get();
         for(size_t i = 0; i < embeddings.size(); i++) {
-            (*documents[i])[field.name] = embeddings[i];
+            auto& embedding = embeddings[i];
+            auto& document = texts_to_embed[i].first;
+            (*document)[field.name] = embedding;
         }
     }
     return Option<bool>(true);
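
Why sort by length: batch_embed (below) slices the texts into fixed-size batches and pads each batch to its longest member, so grouping similar lengths together minimizes wasted padding. As a worked example with a batch size of 2 and lengths {2, 100, 3, 98}: unsorted batches {2, 100} and {3, 98} pad to 2*100 + 2*98 = 396 token slots, while sorted batches {2, 3} and {98, 100} pad to 2*3 + 2*100 = 206. Carrying the document pointer in each pair is what lets the final loop write each embedding back to the right document after this reordering.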
diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp
index a29c4062..9e168424 100644
--- a/src/text_embedder.cpp
+++ b/src/text_embedder.cpp
@@ -9,6 +9,14 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
     // create environment
     Ort::SessionOptions session_options;
+    auto providers = Ort::GetAvailableProviders();
+    for(auto& provider : providers) {
+        if(provider == "CUDAExecutionProvider") {
+            LOG(INFO) << "Using CUDAExecutionProvider";
+            OrtCUDAProviderOptions cuda_options;
+            session_options.AppendExecutionProvider_CUDA(cuda_options);
+        }
+    }
     std::string abs_path = TextEmbedderManager::get_absolute_model_path(model_name);
     LOG(INFO) << "Loading model from disk: " << abs_path;
     session_ = std::make_unique<Ort::Session>(env_, abs_path.c_str(), session_options);
@@ -128,9 +136,83 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
 
 Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
     std::vector<std::vector<float>> outputs;
     if(!is_remote()) {
-        // for now only openai is supported for batch embedding
-        for(const auto& input : inputs) {
-            outputs.push_back(Embed(input).get());
+        for(int i = 0; i < inputs.size(); i += 8) {
+            auto input_batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + 8, static_cast<int>(inputs.size())));
+            auto encoded_inputs = batch_encode(input_batch);
+
+            // create input tensor object from data values
+            Ort::AllocatorWithDefaultOptions allocator;
+            Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
+            std::vector<Ort::Value> input_tensors;
+            std::vector<std::vector<int64_t>> input_shapes;
+            std::vector<const char*> input_node_names = {"input_ids", "attention_mask"};
+            // If model is DistilBERT or sentencepiece, it has 2 inputs, else it has 3 inputs
+            if(session_->GetInputCount() == 3) {
+                input_node_names.push_back("token_type_ids");
+            }
+            input_shapes.push_back({static_cast<int64_t>(encoded_inputs.input_ids.size()), static_cast<int64_t>(encoded_inputs.input_ids[0].size())});
+            input_shapes.push_back({static_cast<int64_t>(encoded_inputs.attention_mask.size()), static_cast<int64_t>(encoded_inputs.attention_mask[0].size())});
+            if(session_->GetInputCount() == 3) {
+                input_shapes.push_back({static_cast<int64_t>(encoded_inputs.token_type_ids.size()), static_cast<int64_t>(encoded_inputs.token_type_ids[0].size())});
+            }
+
+            std::vector<int64_t> input_ids_flatten;
+            std::vector<int64_t> attention_mask_flatten;
+            std::vector<int64_t> token_type_ids_flatten;
+
+            for (int i = 0; i < encoded_inputs.input_ids.size(); i++) {
+                for (int j = 0; j < encoded_inputs.input_ids[i].size(); j++) {
+                    input_ids_flatten.push_back(encoded_inputs.input_ids[i][j]);
+                }
+            }
+
+            for (int i = 0; i < encoded_inputs.attention_mask.size(); i++) {
+                for (int j = 0; j < encoded_inputs.attention_mask[i].size(); j++) {
+                    attention_mask_flatten.push_back(encoded_inputs.attention_mask[i][j]);
+                }
+            }
+
+            if(session_->GetInputCount() == 3) {
+                for (int i = 0; i < encoded_inputs.token_type_ids.size(); i++) {
+                    for (int j = 0; j < encoded_inputs.token_type_ids[i].size(); j++) {
+                        token_type_ids_flatten.push_back(encoded_inputs.token_type_ids[i][j]);
+                    }
+                }
+            }
+
+            input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, input_ids_flatten.data(), input_ids_flatten.size(), input_shapes[0].data(), input_shapes[0].size()));
+            input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, attention_mask_flatten.data(), attention_mask_flatten.size(), input_shapes[1].data(), input_shapes[1].size()));
+            if(session_->GetInputCount() == 3) {
+                input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, token_type_ids_flatten.data(), token_type_ids_flatten.size(), input_shapes[2].data(), input_shapes[2].size()));
+            }
+
+            //LOG(INFO) << "Running model";
+            // create output tensor object
+            std::vector<const char*> output_node_names = {output_tensor_name.c_str()};
+
+            // if seq length is 0, return empty vectors for this batch
+            if(input_shapes[0][1] == 0) {
+                for(int i = 0; i < input_batch.size(); i++) {
+                    outputs.push_back(std::vector<float>());
+                }
+                continue;
+            }
+
+            auto output_tensor = session_->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_tensors.size(), output_node_names.data(), output_node_names.size());
+            float* data = output_tensor[0].GetTensorMutableData<float>();
+            // read back the output tensor shape: [batch, seq, hidden]
+            auto shape = output_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+            for (int i = 0; i < shape[0]; i++) {
+                std::vector<std::vector<float>> output;
+                for (int j = 0; j < shape[1]; j++) {
+                    std::vector<float> output_row;
+                    for (int k = 0; k < shape[2]; k++) {
+                        output_row.push_back(data[i * shape[1] * shape[2] + j * shape[2] + k]);
+                    }
+                    output.push_back(output_row);
+                }
+                outputs.push_back(mean_pooling(output));
+            }
         }
     } else {
         auto embed_op = remote_embedder_->batch_embed(inputs);
@@ -264,4 +346,37 @@ Option<bool> TextEmbedder::validate_local_or_public_model(const nlohmann::json&
     }
 
     return Option<bool>(true);
+}
+
+
+batch_encoded_input_t TextEmbedder::batch_encode(const std::vector<std::string>& inputs) {
+    batch_encoded_input_t encoded_inputs;
+    for(auto& input : inputs) {
+        auto encoded_input = tokenizer_->Encode(input);
+        encoded_inputs.input_ids.push_back(encoded_input.input_ids);
+        encoded_inputs.attention_mask.push_back(encoded_input.attention_mask);
+        encoded_inputs.token_type_ids.push_back(encoded_input.token_type_ids);
+    }
+
+    // Pad inputs
+    size_t max_input_len = 0;
+    for(auto& input_ids : encoded_inputs.input_ids) {
+        if(input_ids.size() > max_input_len) {
+            max_input_len = input_ids.size();
+        }
+    }
+
+    for(auto& input_ids : encoded_inputs.input_ids) {
+        input_ids.resize(max_input_len, 0);
+    }
+
+    for(auto& attention_mask : encoded_inputs.attention_mask) {
+        attention_mask.resize(max_input_len, 0);
+    }
+
+    for(auto& token_type_ids : encoded_inputs.token_type_ids) {
+        token_type_ids.resize(max_input_len, 0);
+    }
+
+    return encoded_inputs;
 }
\ No newline at end of file
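
For reference, the ONNX output consumed above is a single row-major float buffer of shape [batch, seq, hidden], so element (i, j, k) sits at offset i*seq*hidden + j*hidden + k; the nested loops rebuild per-token rows before handing them to mean_pooling. A self-contained sketch of the same arithmetic done directly on the flat buffer (hypothetical helper, not the patch's mean_pooling signature):

    #include <cstddef>
    #include <vector>

    // Mean-pool one sequence out of a row-major [batch, seq, hidden]
    // buffer, averaging the token embeddings over the seq axis.
    std::vector<float> mean_pool_flat(const float* data, size_t batch_idx,
                                      size_t seq, size_t hidden) {
        std::vector<float> pooled(hidden, 0.0f);
        const float* base = data + batch_idx * seq * hidden;
        for (size_t j = 0; j < seq; j++) {
            for (size_t k = 0; k < hidden; k++) {
                pooled[k] += base[j * hidden + k];  // element (batch_idx, j, k)
            }
        }
        for (float& v : pooled) {
            v /= static_cast<float>(seq);
        }
        return pooled;
    }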