mirror of https://github.com/typesense/typesense.git (synced 2025-05-16 11:28:44 +08:00)
352 lines, 14 KiB
diff --git a/operators/string_utils.cc b/operators/string_utils_onnx.cc
similarity index 99%
rename from operators/string_utils.cc
rename to operators/string_utils_onnx.cc
index ecb6713..91dbe76 100644
--- a/operators/string_utils.cc
+++ b/operators/string_utils_onnx.cc
@@ -2,7 +2,7 @@
 #include "farmhash.h"
 #endif

-#include "string_utils.h"
+#include "string_utils_onnx.h"

 std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
   std::vector<std::string_view> result;
diff --git a/operators/string_utils.h b/operators/string_utils_onnx.h
similarity index 89%
rename from operators/string_utils.h
rename to operators/string_utils_onnx.h
index 5653fbd..6556666 100644
--- a/operators/string_utils.h
+++ b/operators/string_utils_onnx.h
@@ -4,7 +4,6 @@
 #include <iostream>
 #include <sstream>
 #include <vector>
-#include "ocos.h"

 template <typename T>
 inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
@@ -23,11 +22,6 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
   ss << "]";
 }

-template <>
-inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
-  MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
-}
-
 template <>
 inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
   ss << "[";
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
index 324a774..00eac2b 100644
--- a/operators/tokenizer/basic_tokenizer.cc
+++ b/operators/tokenizer/basic_tokenizer.cc
@@ -1,9 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#include "string_utils.h"
+#include "string_utils_onnx.h"
 #include "basic_tokenizer.hpp"
-#include "string_tensor.h"
 #include <vector>
 #include <locale>
 #include <codecvt>
@@ -81,52 +80,3 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
   push_current_token_and_clear();
   return result;
 }
-
-KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
-  bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
-  bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
-  bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
-  bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
-  bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
-
-  tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
-}
-
-void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  OrtTensorDimensions dimensions(ort_, input);
-  if (dimensions.size() != 1 && dimensions[0] != 1) {
-    ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
-  }
-
-  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
-  std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
-
-  FillTensorDataString(api_, ort_, context, result, output);
-}
-
-void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return CreateKernelImpl(api, info);
-};
-
-const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
-
-size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
-
-size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp
index 046499e..9fd6f1a 100644
--- a/operators/tokenizer/basic_tokenizer.hpp
+++ b/operators/tokenizer/basic_tokenizer.hpp
@@ -3,8 +3,7 @@

 #pragma once

-#include "ocos.h"
-#include "string_utils.h"
+#include "string_utils_onnx.h"
 #include "ustring.h"

 class BasicTokenizer {
@@ -19,19 +18,3 @@ class BasicTokenizer {
   bool tokenize_punctuation_;
   bool remove_control_chars_;
 };
-
-struct KernelBasicTokenizer : BaseKernel {
-  KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
-  void Compute(OrtKernelContext* context);
- private:
-  std::shared_ptr<BasicTokenizer> tokenizer_;
-};
-
-struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
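
Note: with the two basic_tokenizer hunks above, the Ort kernel wrapper is gone and BasicTokenizer is left as a plain C++ class. A minimal usage sketch, built only from the constructor argument order and the Tokenize(ustring) signature visible in the removed kernel code, and assuming ustring converts to and from std::string as it does elsewhere in this diff:

    #include <iostream>
    #include <string>
    #include <vector>
    #include "basic_tokenizer.hpp"
    #include "ustring.h"

    int main() {
      // Same flag values the removed KernelBasicTokenizer read as attribute defaults.
      BasicTokenizer tokenizer(/*do_lower_case=*/true,
                               /*tokenize_chinese_chars=*/true,
                               /*strip_accents=*/false,
                               /*tokenize_punctuation=*/false,
                               /*remove_control_chars=*/true);

      std::vector<ustring> tokens = tokenizer.Tokenize(ustring(std::string("Hello world")));
      for (const ustring& token : tokens) {
        std::cout << std::string(token) << '\n';  // ustring -> std::string, as in FindTokenId
      }
    }
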
diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
index b860ba6..9f43c5e 100644
--- a/operators/tokenizer/bert_tokenizer.cc
+++ b/operators/tokenizer/bert_tokenizer.cc
@@ -33,7 +33,8 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {

   auto it = vocab_.find(utf8_token);
   if (it == vocab_.end()) {
-    ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
+    std::cout << "[BertTokenizerVocab]: can not find tokens: " + std::string(token);
+    return -1;
   }

   return it->second;
@@ -276,138 +277,3 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
   }
 }

-KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
-  std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
-  bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
-  bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
-  std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
-  std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
-  std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
-  std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
-  std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
-  bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
-  bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
-  std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
-  std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
-  int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
-
-  tokenizer_ = std::make_unique<BertTokenizer>(
-      vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
-      ustring(sep_token), ustring(pad_token), ustring(cls_token),
-      ustring(mask_token), tokenize_chinese_chars, strip_accents,
-      ustring(suffix_indicator), max_len, truncation_strategy_name);
-}
-
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  if (input_data.size() != 1 && input_data.size() != 2) {
-    ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
-  }
-  std::vector<int64_t> input_ids;
-  std::vector<int64_t> token_type_ids;
-
-  if (input_data.size() == 1) {
-    std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
-    std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
-    tokenizer_->Truncate(encoded);
-    input_ids = tokenizer_->AddSpecialToken(encoded);
-    token_type_ids = tokenizer_->GenerateTypeId(encoded);
-  } else {
-    std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
-    std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
-    std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
-    std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
-    input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
-    token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
-  }
-
-  std::vector<int64_t> attention_mask(input_ids.size(), 1);
-
-  std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
-
-  SetOutput(context, 0, output_dim, input_ids);
-  SetOutput(context, 1, output_dim, token_type_ids);
-  SetOutput(context, 2, output_dim, attention_mask);
-}
-
-void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
-
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
-
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  if (input_data.size() != 2) {
-    ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
-  }
-
-  std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
-  std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
-  std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
-  std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
-  std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
-  std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
-  std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
-
-  const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
-  const std::vector<int64_t> inner_dims{1LL};
-  for (int32_t i = 0; i < 3; ++i) {
-    OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
-    OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
-    ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
-    ort_.ReleaseTensorTypeAndShapeInfo(info);
-  }
-
-  SetOutput(context, 0, outer_dims, input_ids);
-  SetOutput(context, 1, outer_dims, attention_mask);
-  SetOutput(context, 2, outer_dims, token_type_ids);
-}
-
-void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
-
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
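
Note: the one behavioral change in the bert_tokenizer.cc diff above is in BertTokenizerVocab::FindTokenId, which now logs the missing token and returns -1 instead of throwing, so encoding proceeds past unknown tokens and callers must handle the sentinel. A caller-side sketch of what that implies; FindTokenIdOrUnk is a hypothetical helper, not part of the diff:

    // Map the new -1 sentinel back to a caller-chosen [UNK] id.
    int32_t FindTokenIdOrUnk(BertTokenizerVocab& vocab, const ustring& token, int32_t unk_id) {
      int32_t id = vocab.FindTokenId(token);  // returns -1 for unknown tokens after this change
      return id < 0 ? unk_id : id;
    }
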
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
index 6dfcd84..10565e4 100644
--- a/operators/tokenizer/bert_tokenizer.hpp
+++ b/operators/tokenizer/bert_tokenizer.hpp
@@ -3,12 +3,11 @@

 #pragma once

+#include <memory>
 #include <unordered_map>
 #include <vector>
-#include "ocos.h"
 #include "ustring.h"
-#include "string_utils.h"
-#include "string_tensor.h"
+#include "string_utils_onnx.h"
 #include "basic_tokenizer.hpp"

 class BertTokenizerVocab final {
@@ -89,33 +88,4 @@ class BertTokenizer final {
   std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
 };

-struct KernelBertTokenizer : BaseKernel {
-  KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
-  void Compute(OrtKernelContext* context);
-
- protected:
-  std::unique_ptr<BertTokenizer> tokenizer_;
-};
-
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
-
-struct KernelHfBertTokenizer : KernelBertTokenizer {
-  KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
-  void Compute(OrtKernelContext* context);
-};
-
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
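
Note: taken together, these hunks detach the tokenizers from the ONNX Runtime custom-op layer: the Kernel*/CustomOp* wrappers and the ocos.h/string_tensor.h includes are removed, string_utils is renamed to string_utils_onnx, and what remains are plain C++ classes. A sketch of driving BertTokenizer directly, mirroring the single-query path of the removed KernelBertTokenizer::Compute; whether the first constructor argument holds the vocabulary file's path or its contents is an assumption (the kernel read it from a "vocab_file" attribute):

    #include <string>
    #include <vector>
    #include "bert_tokenizer.hpp"

    std::vector<int64_t> EncodeQuery(const std::string& vocab, const std::string& text) {
      // Defaults mirror the attribute defaults of the removed kernel.
      BertTokenizer tokenizer(vocab, /*do_lower_case=*/true, /*do_basic_tokenize=*/true,
                              ustring("[UNK]"), ustring("[SEP]"), ustring("[PAD]"),
                              ustring("[CLS]"), ustring("[MASK]"),
                              /*tokenize_chinese_chars=*/true, /*strip_accents=*/false,
                              ustring("##"), /*max_len=*/-1, "longest_first");

      // Same pipeline the removed Compute ran for a single query.
      std::vector<ustring> tokens = tokenizer.Tokenize(ustring(text));
      std::vector<int64_t> encoded = tokenizer.Encode(tokens);
      tokenizer.Truncate(encoded);
      return tokenizer.AddSpecialToken(encoded);  // token_type_ids would come from GenerateTypeId(encoded)
    }
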