Merge pull request #1428 from ozanarmagan/v0.26-facets

Use new CLIPTokenizer
Kishore Nallan 2023-12-13 09:35:12 +05:30 committed by GitHub
commit e98e1d76bd
8 changed files with 106 additions and 34 deletions

BUILD

@@ -54,6 +54,7 @@ cc_library(
         "@rocksdb",
         "@s2geometry",
         "@hnsw",
+        "@clip_tokenizer//:clip"
         # "@zip",
     ],
 )


@@ -51,6 +51,13 @@ new_git_repository(
     patch_cmds= ["git submodule sync && git submodule foreach 'git fetch --tags' && git submodule update --init --remote"]
 )
 
+new_git_repository(
+    name="clip_tokenizer",
+    branch="master",
+    remote="https://github.com/ozanarmagan/clip_tokenizer_cpp",
+    build_file = "//bazel:clip_tokenizer.BUILD",
+)
+
 new_git_repository(
     name = "onnx_runtime_extensions",
     build_file = "//bazel:onnxruntime_extensions.BUILD",

bazel/clip_tokenizer.BUILD (new file)

@@ -0,0 +1,9 @@
+cc_library(
+    name="clip",
+    srcs=["clip_tokenizer.cpp"],
+    hdrs= glob(["**/*.h"]),
+    includes=["."],
+    deps=["@icu"],
+    visibility=["//visibility:public"],
+    linkstatic=1
+)


@@ -4,6 +4,7 @@
 #include <unordered_map>
 #include <sentencepiece_processor.h>
 #include <tokenizer/bert_tokenizer.hpp>
+#include <clip_tokenizer.h>
 #include <core/session/onnxruntime_cxx_api.h>
 #include <mutex>
@@ -74,13 +75,12 @@ class XLMRobertaTokenizer : public TextEmbeddingTokenizer {
     }
 };
 
-class CLIPTokenizer : public TextEmbeddingTokenizer {
+class CLIPTokenizerWrapper : public TextEmbeddingTokenizer {
     private:
-        std::unique_ptr<Ort::Session> session_;
-        Ort::Env env_;
+        std::unique_ptr<CLIPTokenizer> clip_tokenizer_;
         std::mutex mutex_;
     public:
-        CLIPTokenizer(const std::string& model_path);
+        CLIPTokenizerWrapper(const std::string& vocab_path);
         encoded_input_t Encode(const std::string& text) override;
         virtual TokenizerType get_tokenizer_type() override {
            return TokenizerType::clip;
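Aside: the hunk above only declares the new wrapper; it does not show encoded_input_t or the base interface it overrides. The self-contained sketch below illustrates the pattern this commit follows, using simplified stand-in types. The encoded_input_t field names, the StubClipTokenizer and its tokenize() result shape are assumptions inferred from this diff, not the real Typesense or clip_tokenizer_cpp definitions. Note that the unused mutex_ member appears to remain in the real header even though the new Encode() path shown further down no longer locks it.

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins; field names are inferred from the
// `return {input_ids, {}, attention_mask};` in this commit and are assumptions.
struct encoded_input_t {
    std::vector<int64_t> input_ids;
    std::vector<int64_t> token_type_ids;   // stays empty for CLIP
    std::vector<int64_t> attention_mask;
};

enum class TokenizerType { bert, xlm_roberta, clip };

// Minimal version of the interface the header extends.
class TextEmbeddingTokenizer {
public:
    virtual ~TextEmbeddingTokenizer() = default;
    virtual encoded_input_t Encode(const std::string& text) = 0;
    virtual TokenizerType get_tokenizer_type() = 0;
};

// Stand-in for clip_tokenizer_cpp's CLIPTokenizer (vocab loading and BPE omitted).
class StubClipTokenizer {
public:
    explicit StubClipTokenizer(const std::string& /*vocab_path*/) {}
    struct result_t {
        std::vector<std::vector<int>> tokens;
        std::vector<std::vector<int>> attention_mask;
    };
    result_t tokenize(const std::vector<std::string>& texts) {
        // Real BPE tokenization happens inside clip_tokenizer_cpp; emit dummy ids here.
        result_t res;
        for (size_t i = 0; i < texts.size(); i++) {
            res.tokens.push_back({49406, 320, 49407});   // illustrative ids only
            res.attention_mask.push_back({1, 1, 1});
        }
        return res;
    }
};

// Shape of the wrapper added by this commit: own a library tokenizer and widen
// its int token ids to the int64_t vectors the ONNX text model expects.
class ClipTokenizerWrapperSketch : public TextEmbeddingTokenizer {
public:
    explicit ClipTokenizerWrapperSketch(const std::string& vocab_path)
        : clip_tokenizer_(std::make_unique<StubClipTokenizer>(vocab_path)) {}

    encoded_input_t Encode(const std::string& text) override {
        auto res = clip_tokenizer_->tokenize({text});
        std::vector<int64_t> ids(res.tokens[0].begin(), res.tokens[0].end());
        std::vector<int64_t> mask(res.attention_mask[0].begin(), res.attention_mask[0].end());
        return {ids, {}, mask};
    }

    TokenizerType get_tokenizer_type() override { return TokenizerType::clip; }

private:
    std::unique_ptr<StubClipTokenizer> clip_tokenizer_;
};

int main() {
    ClipTokenizerWrapperSketch tok("vocab.json");          // hypothetical vocab path
    encoded_input_t enc = tok.Encode("a photo of a dog");
    return enc.input_ids.size() == enc.attention_mask.size() ? 0 : 1;
}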


@@ -62,7 +62,7 @@ Option<bool> EmbedderManager::validate_and_init_local_model(const nlohmann::json
     }
 
     std::string abs_path = EmbedderManager::get_absolute_model_path(
-                            EmbedderManager::get_model_name_without_namespace(model_name));
+                            EmbedderManager::get_model_name_without_namespace(model_name));
 
     if(!std::filesystem::exists(abs_path)) {
         LOG(ERROR) << "Model file not found: " << abs_path;


@@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
     else if(tokenizer_type == TokenizerType::xlm_roberta) {
         tokenizer_ = std::make_unique<XLMRobertaTokenizer>(vocab_path);
     } else if(tokenizer_type == TokenizerType::clip) {
-        tokenizer_ = std::make_unique<CLIPTokenizer>(EmbedderManager::get_model_subdir(model_name));
+        tokenizer_ = std::make_unique<CLIPTokenizerWrapper>(vocab_path);
         output_tensor_name = "text_embeds";
         num_dim = 512;
         return;


@@ -2,6 +2,7 @@
 #include <sstream>
 #include "text_embedder_tokenizer.h"
 #include "logger.h"
+#include <unicode/normalizer2.h>
 
 BertTokenizerWrapper::BertTokenizerWrapper(const std::string& vocab_path) {
@@ -89,36 +90,21 @@ encoded_input_t XLMRobertaTokenizer::Encode(const std::string& text) {
 }
 
-CLIPTokenizer::CLIPTokenizer(const std::string& model_path) {
-    Ort::SessionOptions session_options;
-    session_options.EnableOrtCustomOps();
-    auto tokenizer_path= model_path + "/clip_tokenizer.onnx";
-    LOG(INFO) << "Loading tokenizer from " << tokenizer_path;
-    session_ = std::make_unique<Ort::Session>(env_, tokenizer_path.c_str(), session_options);
+CLIPTokenizerWrapper::CLIPTokenizerWrapper(const std::string& vocab_path) {
+    try {
+        clip_tokenizer_ = std::make_unique<CLIPTokenizer>(vocab_path);
+    } catch (const std::exception& e) {
+        LOG(INFO) << "Failed to load CLIP tokenizer: " << e.what();
+        throw;
+    }
 }
 
-encoded_input_t CLIPTokenizer::Encode(const std::string& text) {
-    std::unique_lock<std::mutex> lock(mutex_);
-    Ort::AllocatorWithDefaultOptions allocator;
-    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
-    std::vector<Ort::Value> input_tensors;
-    std::vector<int64_t> input_shape = {1};
-    std::vector<const char*> input_names = {"string_input"};
-    const char* const input_array[] = {text.c_str()};
-    Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, input_shape.data(), input_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
-    input_tensor.FillStringTensor(input_array, 1U);
-    input_tensors.push_back(std::move(input_tensor));
-    const std::vector<const char*> output_names = {"input_ids", "attention_mask"};
-    auto output_tensors = session_->Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size());
-    auto input_ids_tensor = output_tensors[0].GetTensorMutableData<int64_t>();
-    auto attention_mask_tensor = output_tensors[1].GetTensorMutableData<int64_t>();
-    auto input_ids = std::vector<int64_t>(input_ids_tensor, input_ids_tensor + output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount());
-    auto attention_mask = std::vector<int64_t>(attention_mask_tensor, attention_mask_tensor + output_tensors[1].GetTensorTypeAndShapeInfo().GetElementCount());
-    return {std::move(input_ids), {}, std::move(attention_mask)};
+encoded_input_t CLIPTokenizerWrapper::Encode(const std::string& text) {
+    auto res = clip_tokenizer_->tokenize({text});
+    // convert vector int to vector int64_t
+    std::vector<int64_t> input_ids(res.tokens[0].begin(), res.tokens[0].end());
+    std::vector<int64_t> attention_mask(res.attention_mask[0].begin(), res.attention_mask[0].end());
+    return {input_ids, {}, attention_mask};
 }
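Note that after this change the tokenizer only produces token ids and an attention mask; embedding still happens in TextEmbedder, which (per an earlier hunk in this commit) reads a "text_embeds" output with 512 dimensions for the CLIP path. The standalone sketch below shows how such an encoded input could be fed to a CLIP text ONNX model with the ONNX Runtime C++ API. The model filename, the graph input names "input_ids"/"attention_mask" and the token id values are illustrative assumptions; only the "text_embeds" output name and the 512-dim expectation come from this commit.

#include <onnxruntime_cxx_api.h>   // include path may differ depending on the ONNX Runtime install
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "clip_text");
    Ort::SessionOptions opts;
    Ort::Session session(env, "clip_text_model.onnx", opts);   // hypothetical model path

    // Pretend these came from CLIPTokenizerWrapper::Encode(); ids are illustrative only.
    std::vector<int64_t> input_ids = {49406, 1929, 49407};
    std::vector<int64_t> attention_mask(input_ids.size(), 1);

    std::vector<int64_t> shape = {1, static_cast<int64_t>(input_ids.size())};
    Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

    // Wrap the int64 vectors as [1, seq_len] tensors without copying the data.
    std::vector<Ort::Value> inputs;
    inputs.push_back(Ort::Value::CreateTensor<int64_t>(mem, input_ids.data(), input_ids.size(),
                                                       shape.data(), shape.size()));
    inputs.push_back(Ort::Value::CreateTensor<int64_t>(mem, attention_mask.data(), attention_mask.size(),
                                                       shape.data(), shape.size()));

    const char* input_names[]  = {"input_ids", "attention_mask"};   // assumed graph input names
    const char* output_names[] = {"text_embeds"};                   // output name used by the clip branch

    auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(),
                               output_names, 1);

    float* embeds = outputs[0].GetTensorMutableData<float>();
    size_t count  = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
    std::cout << "embedding dims: " << count << " first value: " << embeds[0] << std::endl;   // expect 512 dims
    return 0;
}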


@@ -3151,3 +3151,72 @@ TEST_F(CollectionVectorTest, TestInvalidImage) {
     ASSERT_EQ(add_op.error(), "Error while processing image");
 }
 
+TEST_F(CollectionVectorTest, TestCLIPTokenizerUnicode) {
+    auto schema_json =
+        R"({
+            "name": "Images",
+            "fields": [
+                {"name": "name", "type": "string"},
+                {"name": "image", "type": "image", "store": false},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["image"], "model_config": {"model_name": "ts/clip-vit-b-p32"}}}
+            ]
+        })"_json;
+
+    EmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    auto coll = collection_create_op.get();
+
+    // test english
+    auto results = coll->search("dog", {"embedding"},
+                                "", {}, {}, {2}, 10,
+                                1, FREQUENCY, {true},
+                                0, spp::sparse_hash_set<std::string>()).get();
+
+    // test chinese
+    results = coll->search("狗", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test japanese
+    results = coll->search("犬", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test korean
+    results = coll->search("개", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test russian
+    results = coll->search("собака", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test arabic
+    results = coll->search("كلب", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test turkish
+    results = coll->search("köpek", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    results = coll->search("öğ", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+}