Merge pull request #1428 from ozanarmagan/v0.26-facets

Use new CLIPTokenizer

commit e98e1d76bd
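In short: CLIP text tokenization previously ran a clip_tokenizer.onnx graph inside an Ort::Session (with ONNX Runtime custom ops enabled); this commit swaps that for the clip_tokenizer_cpp library, fetched and built via Bazel, and renames the Typesense-side class to CLIPTokenizerWrapper, which now simply owns a CLIPTokenizer instance from that library.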
BUILD
@@ -54,6 +54,7 @@ cc_library(
         "@rocksdb",
         "@s2geometry",
         "@hnsw",
+        "@clip_tokenizer//:clip"
         # "@zip",
     ],
 )
WORKSPACE
@@ -51,6 +51,13 @@ new_git_repository(
     patch_cmds= ["git submodule sync && git submodule foreach 'git fetch --tags' && git submodule update --init --remote"]
 )
 
+new_git_repository(
+    name="clip_tokenizer",
+    branch="master",
+    remote="https://github.com/ozanarmagan/clip_tokenizer_cpp",
+    build_file = "//bazel:clip_tokenizer.BUILD",
+)
+
 new_git_repository(
     name = "onnx_runtime_extensions",
     build_file = "//bazel:onnxruntime_extensions.BUILD",
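For readers unfamiliar with the mechanism: new_git_repository clones a repository that ships no BUILD file of its own and overlays the one supplied via build_file, so //bazel:clip_tokenizer.BUILD (added below) is what turns the cloned clip_tokenizer_cpp sources into the @clip_tokenizer//:clip cc_library that BUILD now depends on. Note that pinning to branch="master" rather than a commit means builds will track the library's head.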
bazel/clip_tokenizer.BUILD (new file)
@@ -0,0 +1,9 @@
+cc_library(
+    name="clip",
+    srcs=["clip_tokenizer.cpp"],
+    hdrs= glob(["**/*.h"]),
+    includes=["."],
+    deps=["@icu"],
+    visibility=["//visibility:public"],
+    linkstatic=1
+)
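The new target's only dependency is @icu, which lines up with the #include <unicode/normalizer2.h> added to the tokenizer source file further down: ICU's Normalizer2 is presumably what the CLIP tokenizer uses to normalize input text before byte-pair encoding it.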
@@ -4,6 +4,7 @@
 #include <unordered_map>
 #include <sentencepiece_processor.h>
 #include <tokenizer/bert_tokenizer.hpp>
+#include <clip_tokenizer.h>
 #include <core/session/onnxruntime_cxx_api.h>
 #include <mutex>
 
@@ -74,13 +75,12 @@ class XLMRobertaTokenizer : public TextEmbeddingTokenizer {
     }
 };
 
-class CLIPTokenizer : public TextEmbeddingTokenizer {
+class CLIPTokenizerWrapper : public TextEmbeddingTokenizer {
     private:
-        std::unique_ptr<Ort::Session> session_;
-        Ort::Env env_;
+        std::unique_ptr<CLIPTokenizer> clip_tokenizer_;
         std::mutex mutex_;
     public:
-        CLIPTokenizer(const std::string& model_path);
+        CLIPTokenizerWrapper(const std::string& vocab_path);
         encoded_input_t Encode(const std::string& text) override;
         virtual TokenizerType get_tokenizer_type() override {
             return TokenizerType::clip;
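To make the new interface concrete, here is a hypothetical usage sketch. It is not part of the commit: the path is illustrative, and encoded_input_t's members are inferred from the `return {input_ids, {}, attention_mask};` statement in the Encode implementation further down.

```cpp
// Hypothetical sketch, not from the commit: driving CLIPTokenizerWrapper
// directly. Assumes encoded_input_t exposes input_ids and attention_mask
// vectors, and that the path points at a downloaded CLIP vocabulary.
#include <iostream>
#include "text_embedder_tokenizer.h"

int main() {
    // Illustrative path; in Typesense this comes from the model directory.
    CLIPTokenizerWrapper tokenizer("/tmp/typesense_test/models/clip-vit-b-p32");
    encoded_input_t enc = tokenizer.Encode("a photo of a dog");
    std::cout << enc.input_ids.size() << " token ids, "
              << enc.attention_mask.size() << " mask entries\n";
    return 0;
}
```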
@@ -62,7 +62,7 @@ Option<bool> EmbedderManager::validate_and_init_local_model(const nlohmann::json
     }
 
     std::string abs_path = EmbedderManager::get_absolute_model_path(
-                                EmbedderManager::get_model_name_without_namespace(model_name));
+                            EmbedderManager::get_model_name_without_namespace(model_name));
 
     if(!std::filesystem::exists(abs_path)) {
         LOG(ERROR) << "Model file not found: " << abs_path;
@@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
     else if(tokenizer_type == TokenizerType::xlm_roberta) {
         tokenizer_ = std::make_unique<XLMRobertaTokenizer>(vocab_path);
     } else if(tokenizer_type == TokenizerType::clip) {
-        tokenizer_ = std::make_unique<CLIPTokenizer>(EmbedderManager::get_model_subdir(model_name));
+        tokenizer_ = std::make_unique<CLIPTokenizerWrapper>(vocab_path);
         output_tensor_name = "text_embeds";
         num_dim = 512;
         return;
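For context: vocab_path is the same variable already used for the other tokenizers above, replacing the previous model-subdirectory argument, since there is no longer a clip_tokenizer.onnx session to locate. The unchanged text_embeds output tensor and num_dim = 512 match the 512-dimensional text projection of the CLIP ViT-B/32 model exercised in the test below.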
@@ -2,6 +2,7 @@
 #include <sstream>
 #include "text_embedder_tokenizer.h"
 #include "logger.h"
+#include <unicode/normalizer2.h>
 
 
 BertTokenizerWrapper::BertTokenizerWrapper(const std::string& vocab_path) {
@@ -89,36 +90,21 @@ encoded_input_t XLMRobertaTokenizer::Encode(const std::string& text) {
 }
 
 
-CLIPTokenizer::CLIPTokenizer(const std::string& model_path) {
-    Ort::SessionOptions session_options;
-    session_options.EnableOrtCustomOps();
-    auto tokenizer_path= model_path + "/clip_tokenizer.onnx";
-    LOG(INFO) << "Loading tokenizer from " << tokenizer_path;
-    session_ = std::make_unique<Ort::Session>(env_, tokenizer_path.c_str(), session_options);
+CLIPTokenizerWrapper::CLIPTokenizerWrapper(const std::string& vocab_path) {
+    try {
+        clip_tokenizer_ = std::make_unique<CLIPTokenizer>(vocab_path);
+    } catch (const std::exception& e) {
+        LOG(INFO) << "Failed to load CLIP tokenizer: " << e.what();
+        throw;
+    }
 }
 
-encoded_input_t CLIPTokenizer::Encode(const std::string& text) {
-    std::unique_lock<std::mutex> lock(mutex_);
-
-    Ort::AllocatorWithDefaultOptions allocator;
-    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
-    std::vector<Ort::Value> input_tensors;
-    std::vector<int64_t> input_shape = {1};
-    std::vector<const char*> input_names = {"string_input"};
-    const char* const input_array[] = {text.c_str()};
-
-    Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, input_shape.data(), input_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
-    input_tensor.FillStringTensor(input_array, 1U);
-    input_tensors.push_back(std::move(input_tensor));
-
-    const std::vector<const char*> output_names = {"input_ids", "attention_mask"};
-    auto output_tensors = session_->Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size());
-    auto input_ids_tensor = output_tensors[0].GetTensorMutableData<int64_t>();
-    auto attention_mask_tensor = output_tensors[1].GetTensorMutableData<int64_t>();
-
-    auto input_ids = std::vector<int64_t>(input_ids_tensor, input_ids_tensor + output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount());
-    auto attention_mask = std::vector<int64_t>(attention_mask_tensor, attention_mask_tensor + output_tensors[1].GetTensorTypeAndShapeInfo().GetElementCount());
-
-
-    return {std::move(input_ids), {}, std::move(attention_mask)};
+encoded_input_t CLIPTokenizerWrapper::Encode(const std::string& text) {
+    auto res = clip_tokenizer_->tokenize({text});
+
+    // convert vector int to vector int64_t
+    std::vector<int64_t> input_ids(res.tokens[0].begin(), res.tokens[0].end());
+    std::vector<int64_t> attention_mask(res.attention_mask[0].begin(), res.attention_mask[0].end());
+
+    return {input_ids, {}, attention_mask};
 }
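The `// convert vector int to vector int64_t` step above is doing real work: the vendored tokenizer returns 32-bit token ids, while the embedding pipeline works in int64_t, and the vector range constructor widens each element. A minimal standalone illustration (token values shown are just examples):

```cpp
// Standalone illustration of the widening conversion used in Encode above.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // Example ids in the style of CLIP's BPE vocabulary (values illustrative).
    std::vector<int> tokens32 = {49406, 1929, 49407};
    // The range constructor converts each int to int64_t element-by-element.
    std::vector<int64_t> tokens64(tokens32.begin(), tokens32.end());
    assert(tokens64.size() == tokens32.size());
    assert(tokens64[0] == 49406);
    return 0;
}
```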
@@ -3151,3 +3151,72 @@ TEST_F(CollectionVectorTest, TestInvalidImage) {
     ASSERT_EQ(add_op.error(), "Error while processing image");
 }
 
+
+TEST_F(CollectionVectorTest, TestCLIPTokenizerUnicode) {
+    auto schema_json =
+        R"({
+            "name": "Images",
+            "fields": [
+                {"name": "name", "type": "string"},
+                {"name": "image", "type": "image", "store": false},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["image"], "model_config": {"model_name": "ts/clip-vit-b-p32"}}}
+            ]
+        })"_json;
+
+    EmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+
+    auto coll = collection_create_op.get();
+
+    // test english
+    auto results = coll->search("dog", {"embedding"},
+                                "", {}, {}, {2}, 10,
+                                1, FREQUENCY, {true},
+                                0, spp::sparse_hash_set<std::string>()).get();
+
+    // test chinese
+    results = coll->search("狗", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test japanese
+    results = coll->search("犬", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test korean
+    results = coll->search("개", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test russian
+    results = coll->search("собака", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test arabic
+    results = coll->search("كلب", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    // test turkish
+    results = coll->search("kö", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+
+    results = coll->search("öğ", {"embedding"},
+                           "", {}, {}, {2}, 10,
+                           1, FREQUENCY, {true},
+                           0, spp::sparse_hash_set<std::string>()).get();
+}
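As written, the new test asserts only on collection creation; the multilingual searches (English, Chinese, Japanese, Korean, Russian, Arabic, Turkish) carry no assertions on their results, so the test functions as a smoke check that CLIPTokenizerWrapper::Encode handles non-Latin and accented input without crashing.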