diff --git a/operators/string_utils.cc b/operators/string_utils_onnx.cc
similarity index 99%
rename from operators/string_utils.cc
rename to operators/string_utils_onnx.cc
index ecb6713..91dbe76 100644
--- a/operators/string_utils.cc
+++ b/operators/string_utils_onnx.cc
@@ -2,7 +2,7 @@
 #include "farmhash.h"
 #endif
 
-#include "string_utils.h"
+#include "string_utils_onnx.h"
 
 std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
   std::vector<std::string_view> result;
diff --git a/operators/string_utils.h b/operators/string_utils_onnx.h
similarity index 89%
rename from operators/string_utils.h
rename to operators/string_utils_onnx.h
index 5653fbd..6556666 100644
--- a/operators/string_utils.h
+++ b/operators/string_utils_onnx.h
@@ -4,7 +4,6 @@
 #include <sstream>
 #include <string>
 #include <vector>
-#include "ocos.h"
 
 template <typename T>
 inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
@@ -23,11 +22,6 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
   ss << "]";
 }
 
-template <>
-inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
-  MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
-}
-
 template <>
 inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
   ss << "[";
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
index 324a774..00eac2b 100644
--- a/operators/tokenizer/basic_tokenizer.cc
+++ b/operators/tokenizer/basic_tokenizer.cc
@@ -1,9 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "string_utils.h"
+#include "string_utils_onnx.h"
 #include "basic_tokenizer.hpp"
-#include "string_tensor.h"
 #include <vector>
 #include <locale>
 #include <codecvt>
@@ -81,52 +80,3 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
   push_current_token_and_clear();
   return result;
 }
-
-KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
-  bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
-  bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
-  bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
-  bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
-  bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
-
-  tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
-}
-
-void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  OrtTensorDimensions dimensions(ort_, input);
-  if (dimensions.size() != 1 && dimensions[0] != 1) {
-    ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
-  }
-
-  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
-  std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
-
-  FillTensorDataString(api_, ort_, context, result, output);
-}
-
-void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return CreateKernelImpl(api, info);
-};
-
-const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
-
-size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
-
-size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
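The three diffs above decouple the shared string utilities and BasicTokenizer from the ONNX Runtime C ABI: the renamed string_utils_onnx.h drops ocos.h and its OrtTensorDimensions specialization, and basic_tokenizer.cc loses its KernelBasicTokenizer/CustomOpBasicTokenizer wrappers, leaving a plain C++ class. A minimal sketch of driving the now-standalone class directly; the ustring round trips through std::string are grounded in the removed kernel code, while constructing a ustring from a string literal is an assumption here:

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "basic_tokenizer.hpp"
#include "ustring.h"

int main() {
  // The same five options the removed KernelBasicTokenizer read from node
  // attributes, with the same defaults.
  BasicTokenizer tokenizer(/*do_lower_case=*/true,
                           /*tokenize_chinese_chars=*/true,
                           /*strip_accents=*/false,
                           /*tokenize_punctuation=*/false,
                           /*remove_control_chars=*/true);

  std::vector<ustring> tokens = tokenizer.Tokenize(ustring("Hello, World!"));
  for (const auto& token : tokens) {
    std::cout << std::string(token) << '\n';  // ustring -> std::string, as in the removed kernels
  }
  return 0;
}
```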
diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp
index 046499e..9fd6f1a 100644
--- a/operators/tokenizer/basic_tokenizer.hpp
+++ b/operators/tokenizer/basic_tokenizer.hpp
@@ -3,8 +3,7 @@
 
 #pragma once
 
-#include "ocos.h"
-#include "string_utils.h"
+#include "string_utils_onnx.h"
 #include "ustring.h"
 
 class BasicTokenizer {
@@ -19,19 +18,3 @@ class BasicTokenizer {
   bool tokenize_punctuation_;
   bool remove_control_chars_;
 };
-
-struct KernelBasicTokenizer : BaseKernel {
-  KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
-  void Compute(OrtKernelContext* context);
-
- private:
-  std::shared_ptr<BasicTokenizer> tokenizer_;
-};
-
-struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
index b860ba6..9f43c5e 100644
--- a/operators/tokenizer/bert_tokenizer.cc
+++ b/operators/tokenizer/bert_tokenizer.cc
@@ -33,7 +33,8 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
 
   auto it = vocab_.find(utf8_token);
   if (it == vocab_.end()) {
-    ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
+    std::cout << "[BertTokenizerVocab]: can not find tokens: " + std::string(token);
+    return -1;
   }
 
   return it->second;
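Note the behavioral change in FindTokenId: an out-of-vocabulary token no longer throws ORT_RUNTIME_EXCEPTION; it logs to std::cout and returns -1, so every caller must now treat -1 as a sentinel. (The diagnostic is written without a trailing newline or flush, and std::cerr would be the more conventional channel for it.) A sketch of the fallback a caller might add; the helper and the [UNK] substitution are assumptions, not anything this patch provides:

```cpp
#include <cstdint>

#include "bert_tokenizer.hpp"

// Hypothetical helper: resolve a token to an id under the new non-throwing
// contract, falling back to the default unknown token.
int32_t ResolveTokenId(BertTokenizerVocab& vocab, const ustring& token) {
  int32_t id = vocab.FindTokenId(token);
  if (id < 0) {
    // Out-of-vocabulary: FindTokenId now logs and returns -1 instead of
    // throwing, so substitute the [UNK] id rather than propagating -1.
    id = vocab.FindTokenId(ustring("[UNK]"));
  }
  return id;
}
```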
@@ -276,138 +277,3 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(TruncateStrategyType::LONGEST_FIRST) {
   }
 }
 
-KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
-  std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
-  bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
-  bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
-  std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
-  std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
-  std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
-  std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
-  std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
-  bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
-  bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
-  std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
-  std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
-  int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
-
-  tokenizer_ = std::make_unique<BertTokenizer>(
-      vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
-      ustring(sep_token), ustring(pad_token), ustring(cls_token),
-      ustring(mask_token), tokenize_chinese_chars, strip_accents,
-      ustring(suffix_indicator), max_len, truncation_strategy_name);
-}
-
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  if (input_data.size() != 1 && input_data.size() != 2) {
-    ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
-  }
-  std::vector<int64_t> input_ids;
-  std::vector<int64_t> token_type_ids;
-
-  if (input_data.size() == 1) {
-    std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
-    std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
-    tokenizer_->Truncate(encoded);
-    input_ids = tokenizer_->AddSpecialToken(encoded);
-    token_type_ids = tokenizer_->GenerateTypeId(encoded);
-  } else {
-    std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
-    std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
-    std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
-    std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
-    input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
-    token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
-  }
-
-  std::vector<int64_t> attention_mask(input_ids.size(), 1);
-
-  std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
-
-  SetOutput(context, 0, output_dim, input_ids);
-  SetOutput(context, 1, output_dim, token_type_ids);
-  SetOutput(context, 2, output_dim, attention_mask);
-}
-
-void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
-
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
-
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  if (input_data.size() != 2) {
-    ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
-  }
-
-  std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
-  std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
-  std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
-  std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
-  std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
-  std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
-  std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
-
-  const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
-  const std::vector<int64_t> inner_dims{1LL};
-  for (int32_t i = 0; i < 3; ++i) {
-    OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
-    OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
-    ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
-    ort_.ReleaseTensorTypeAndShapeInfo(info);
-  }
-
-  SetOutput(context, 0, outer_dims, input_ids);
-  SetOutput(context, 1, outer_dims, attention_mask);
-  SetOutput(context, 2, outer_dims, token_type_ids);
-}
-
-void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
-
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
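Both removed Compute bodies reduce to the same library calls, which remain available on BertTokenizer after this patch. A sketch of the paired-sentence path that KernelHfBertTokenizer implemented, using only the methods the removed code calls; the function name and the text_a/text_b parameters are placeholders:

```cpp
#include <cstdint>
#include <string>
#include <vector>

#include "bert_tokenizer.hpp"

// Sketch: paired-sentence encoding mirroring the removed
// KernelHfBertTokenizer::Compute, minus the ORT tensor plumbing.
void EncodePair(BertTokenizer& tokenizer, const std::string& text_a, const std::string& text_b) {
  std::vector<ustring> tokens1 = tokenizer.Tokenize(ustring(text_a));
  std::vector<ustring> tokens2 = tokenizer.Tokenize(ustring(text_b));
  std::vector<int64_t> encoded1 = tokenizer.Encode(tokens1);
  std::vector<int64_t> encoded2 = tokenizer.Encode(tokens2);

  // [CLS] A [SEP] B [SEP], with segment ids 0 for A and 1 for B.
  std::vector<int64_t> input_ids = tokenizer.AddSpecialToken(encoded1, encoded2);
  std::vector<int64_t> token_type_ids = tokenizer.GenerateTypeId(encoded1, encoded2);
  std::vector<int64_t> attention_mask(input_ids.size(), 1);

  // The removed kernel wrote these three vectors as INT64 outputs shaped
  // {1, input_ids.size()}; a library caller can consume them directly.
}
```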
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
index 6dfcd84..10565e4 100644
--- a/operators/tokenizer/bert_tokenizer.hpp
+++ b/operators/tokenizer/bert_tokenizer.hpp
@@ -3,12 +3,11 @@
 
 #pragma once
 
+#include <iostream>
 #include <memory>
 #include <unordered_map>
-#include "ocos.h"
 #include "ustring.h"
-#include "string_utils.h"
-#include "string_tensor.h"
+#include "string_utils_onnx.h"
 #include "basic_tokenizer.hpp"
 
 class BertTokenizerVocab final {
@@ -89,33 +88,4 @@ class BertTokenizer final {
   std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
 };
-
-struct KernelBertTokenizer : BaseKernel {
-  KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
-  void Compute(OrtKernelContext* context);
-
- protected:
-  std::unique_ptr<BertTokenizer> tokenizer_;
-};
-
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
-
-struct KernelHfBertTokenizer : KernelBertTokenizer {
-  KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
-  void Compute(OrtKernelContext* context);
-};
-
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
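With the kernel and custom-op structs gone from the header, consumers construct BertTokenizer directly. A sketch using the same attribute defaults the removed KernelBertTokenizer applied; the helper name and the vocab_data parameter are hypothetical, standing in for what the kernel read from the "vocab_file" node attribute:

```cpp
#include <memory>
#include <string>

#include "bert_tokenizer.hpp"

// Hypothetical factory: the argument order and defaults follow the removed
// KernelBertTokenizer constructor verbatim.
std::unique_ptr<BertTokenizer> MakeDefaultBertTokenizer(const std::string& vocab_data) {
  return std::make_unique<BertTokenizer>(
      vocab_data,
      /*do_lower_case=*/true,
      /*do_basic_tokenize=*/true,
      ustring("[UNK]"), ustring("[SEP]"), ustring("[PAD]"),
      ustring("[CLS]"), ustring("[MASK]"),
      /*tokenize_chinese_chars=*/true,
      /*strip_accents=*/false,
      ustring("##"),      // wordpiece suffix indicator
      /*max_len=*/-1,
      "longest_first");   // truncation strategy
}
```

Callers of the resulting tokenizer should also account for the FindTokenId change earlier in this patch: out-of-vocabulary lookups now surface as -1 rather than an exception.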