Merge pull request #885 from ozanarmagan/v0.25

Semantic Search & Hybrid Search
Kishore Nallan committed 2023-03-03 17:00:31 +05:30 (via GitHub)
commit 1ad6044bec
29 changed files with 2161 additions and 21 deletions

.gitignore (vendored, 3 lines changed)

@@ -12,4 +12,5 @@ cmake-build-release
/bazel-*
typesense-server-data/
.clwb/.bazelproject
.vscode/settings.json
.vscode/settings.json
/onnxruntime-prefix

BUILD (131 lines changed)

@@ -37,6 +37,7 @@ cc_library(
}),
deps = [
":headers",
":onnxruntime_lib",
"@com_github_brpc_braft//:braft",
"@com_github_brpc_brpc//:brpc",
"@com_github_google_glog//:glog",
@@ -134,3 +135,133 @@ cc_test(
"ROOT_DIR="
],
)
load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake")
__POSTFIX = """
mkdir -p $INSTALLDIR/lib
mkdir -p $INSTALLDIR/lib/_deps
mkdir -p $INSTALLDIR/lib/_deps/onnx-build
mkdir -p $INSTALLDIR/lib/_deps/re2-build
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
mkdir -p $INSTALLDIR/lib/_deps/protobuf-build
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
mkdir -p $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
mkdir -p $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps
mkdir -p $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
mkdir -p $INSTALLDIR/lib/_deps/google_nsync-build
cp $BUILD_TMPDIR/_deps/onnx-build/libonnx.a $INSTALLDIR/lib/_deps/onnx-build
cp $BUILD_TMPDIR/_deps/onnx-build/libonnx_proto.a $INSTALLDIR/lib/_deps/onnx-build
cp $BUILD_TMPDIR/_deps/re2-build/libre2.a $INSTALLDIR/lib/_deps/re2-build
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_base.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_city.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
cp $BUILD_TMPDIR/_deps/google_nsync-build/libnsync_cpp.a $INSTALLDIR/lib/_deps/google_nsync-build
cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/libcpuinfo.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
"""
cmake(
name = "onnxruntime",
lib_source = "@onnx_runtime//:all_srcs",
cache_entries = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
'onnxruntime_GENERATE_TEST_REPORTS':'ON',
'onnxruntime_USE_MIMALLOC':'OFF',
'onnxruntime_ENABLE_PYTHON':'OFF',
'onnxruntime_BUILD_CSHARP':'OFF',
'onnxruntime_BUILD_JAVA':'OFF',
'onnxruntime_BUILD_NODEJS':'OFF',
'onnxruntime_BUILD_OBJC':'OFF',
'onnxruntime_BUILD_SHARED_LIB':'OFF',
'onnxruntime_BUILD_APPLE_FRAMEWORK':'OFF',
'onnxruntime_USE_DNNL':'OFF',
'onnxruntime_USE_NNAPI_BUILTIN':'OFF',
'onnxruntime_USE_RKNPU':'OFF',
'onnxruntime_USE_LLVM':'OFF',
'onnxruntime_ENABLE_MICROSOFT_INTERNAL':'OFF',
'onnxruntime_USE_VITISAI':'OFF',
'onnxruntime_USE_TENSORRT':'OFF',
'onnxruntime_SKIP_AND_PERFORM_FILTERED_TENSORRT_TESTS':'ON',
'onnxruntime_USE_TENSORRT_BUILTIN_PARSER':'OFF',
'onnxruntime_TENSORRT_PLACEHOLDER_BUILDER':'OFF', 'onnxruntime_USE_TVM':'OFF',
'onnxruntime_TVM_CUDA_RUNTIME':'OFF', 'onnxruntime_TVM_USE_HASH':'OFF',
'onnxruntime_USE_MIGRAPHX':'OFF', 'onnxruntime_CROSS_COMPILING':'OFF',
'onnxruntime_DISABLE_CONTRIB_OPS':'OFF', 'onnxruntime_DISABLE_ML_OPS':'OFF',
'onnxruntime_DISABLE_RTTI':'OFF', 'onnxruntime_DISABLE_EXCEPTIONS':'OFF',
'onnxruntime_MINIMAL_BUILD':'OFF', 'onnxruntime_EXTENDED_MINIMAL_BUILD':'OFF',
'onnxruntime_MINIMAL_BUILD_CUSTOM_OPS':'OFF', 'onnxruntime_REDUCED_OPS_BUILD':'OFF',
'onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS':'OFF', 'onnxruntime_USE_DML':'OFF',
'onnxruntime_USE_WINML':'OFF', 'onnxruntime_BUILD_MS_EXPERIMENTAL_OPS':'OFF',
'onnxruntime_USE_TELEMETRY':'OFF', 'onnxruntime_ENABLE_LTO':'OFF',
'onnxruntime_USE_ACL':'OFF', 'onnxruntime_USE_ACL_1902':'OFF',
'onnxruntime_USE_ACL_1905':'OFF', 'onnxruntime_USE_ACL_1908':'OFF',
'onnxruntime_USE_ACL_2002':'OFF', 'onnxruntime_USE_ARMNN':'OFF',
'onnxruntime_ARMNN_RELU_USE_CPU':'ON', 'onnxruntime_ARMNN_BN_USE_CPU':'ON',
'onnxruntime_ENABLE_NVTX_PROFILE':'OFF', 'onnxruntime_ENABLE_TRAINING':'OFF',
'onnxruntime_ENABLE_TRAINING_OPS':'OFF', 'onnxruntime_ENABLE_TRAINING_APIS':'OFF',
'onnxruntime_ENABLE_CPU_FP16_OPS':'OFF', 'onnxruntime_USE_NCCL':'OFF',
'onnxruntime_BUILD_BENCHMARKS':'OFF', 'onnxruntime_USE_ROCM':'OFF',
'Onnxruntime_GCOV_COVERAGE':'OFF', 'onnxruntime_USE_MPI':'ON',
'onnxruntime_ENABLE_MEMORY_PROFILE':'OFF',
'onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO':'OFF',
'onnxruntime_BUILD_WEBASSEMBLY':'OFF', 'onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING':'ON',
'onnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING':'ON',
'onnxruntime_ENABLE_WEBASSEMBLY_THREADS':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO':'OFF',
'onnxruntime_ENABLE_WEBASSEMBLY_PROFILING':'OFF',
'onnxruntime_ENABLE_EAGER_MODE':'OFF', 'onnxruntime_ENABLE_LAZY_TENSOR':'OFF',
'onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS':'OFF', 'onnxruntime_ENABLE_CUDA_PROFILING':'OFF',
'onnxruntime_ENABLE_ROCM_PROFILING':'OFF', 'onnxruntime_USE_XNNPACK':'OFF',
'onnxruntime_USE_CANN':'OFF', 'CMAKE_TLS_VERIFY':'ON', 'FETCHCONTENT_QUIET':'OFF',
'onnxruntime_PYBIND_EXPORT_OPSCHEMA':'OFF', 'onnxruntime_ENABLE_MEMLEAK_CHECKER':'OFF',
'CMAKE_BUILD_TYPE':'Release',
},
working_directory="cmake",
build_args= [
"--config Release",
"-j6"
],
tags=["requires-network","no-sandbox"],
features=["-default_compile_flags","-fno-canonical-system-headers", "-Wno-builtin-macro-redefined"],
out_static_libs=[ "libonnxruntime_session.a",
"libonnxruntime_optimizer.a",
"libonnxruntime_providers.a",
"libonnxruntime_util.a",
"libonnxruntime_framework.a",
"libonnxruntime_graph.a",
"libonnxruntime_mlas.a",
"libonnxruntime_common.a",
"libonnxruntime_flatbuffers.a",
"_deps/onnx-build/libonnx.a",
"_deps/onnx-build/libonnx_proto.a",
"_deps/re2-build/libre2.a",
"_deps/abseil_cpp-build/absl/base/libabsl_base.a",
"_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a",
"_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a",
"_deps/abseil_cpp-build/absl/hash/libabsl_hash.a",
"_deps/abseil_cpp-build/absl/hash/libabsl_city.a",
"_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a",
"_deps/google_nsync-build/libnsync_cpp.a",
"_deps/pytorch_cpuinfo-build/libcpuinfo.a",
"_deps/pytorch_cpuinfo-build/deps/clog/libclog.a",
],
postfix_script=__POSTFIX
)
cc_library(
name = "onnxruntime_lib",
linkopts = select({
"@platforms//os:linux": ["-static-libstdc++", "-static-libgcc"],
"//conditions:default": [],
}),
deps = ["//:onnxruntime", "@onnx_runtime_extensions//:operators"],
includes= ["onnxruntime/include/onnxruntime"]
)
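
A rough sketch (not code from this PR) of what the new ":onnxruntime_lib" wrapper enables: any cc_* target that depends on it can construct an ONNX Runtime session directly. The include path below follows the includes= prefix set above; the model path is a placeholder.

#include "core/session/onnxruntime_cxx_api.h"

int main() {
    // Process-wide ORT environment plus CPU-only session options.
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "typesense");
    Ort::SessionOptions opts;
    opts.SetIntraOpNumThreads(1);
    // "model.onnx" is a placeholder path; the PR wires up real embedding models elsewhere.
    Ort::Session session(env, "model.onnx", opts);
    return 0;
}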

CMakeLists.txt

@@ -1,11 +1,10 @@
cmake_minimum_required(VERSION 3.17.5)
cmake_minimum_required(VERSION 3.24.0)
project(typesense)
cmake_policy(SET CMP0074 NEW)
cmake_policy(SET CMP0003 NEW)
set(USE_SANTINIZER OFF)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Wall -Wextra -Wno-unused-parameter -Werror=return-type -O2 -g -DNDEBUG")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -Wextra -Wno-unused-parameter -Werror=return-type -std=c++17 -O0 -g")
set(DEP_ROOT_DIR ${CMAKE_SOURCE_DIR}/external-${CMAKE_SYSTEM_NAME})
@@ -62,6 +61,8 @@ ELSE()
set(ENV{CMAKE_FIND_LIBRARY_SUFFIXES} ".a")
ENDIF()
include(cmake/onnxruntime.cmake)
include(cmake/onnxruntime_ext.cmake)
include(cmake/For.cmake)
include(cmake/H2O.cmake)
include(cmake/RocksDB.cmake)
@@ -109,6 +110,9 @@ include_directories(${DEP_ROOT_DIR}/${LRUCACHE_NAME}/include)
include_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/build/include)
include_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/data)
include_directories(${DEP_ROOT_DIR}/${HNSW_NAME})
include_directories(${DEP_ROOT_DIR}/${ONNX_NAME}/include/onnxruntime)
include_directories(${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir)
link_directories(/usr/local/lib)
link_directories(${DEP_ROOT_DIR}/${GTEST_NAME}/googletest/build)
@@ -119,6 +123,18 @@ link_directories(${DEP_ROOT_DIR}/${ICONV_NAME}/lib/.libs)
link_directories(${DEP_ROOT_DIR}/${JEMALLOC_NAME}/lib)
link_directories(${DEP_ROOT_DIR}/${S2_NAME}/build)
link_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/build/lib)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/onnx-build)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/re2-build)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl/base)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/protobuf-build)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl/container)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl/hash)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/pytorch_cpuinfo-build)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/pytorch_cpuinfo-build/deps)
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/pytorch_cpuinfo-build/deps/clog)
set(JEMALLOC_ROOT_DIR "${DEP_ROOT_DIR}/${JEMALLOC_NAME}")
FIND_PACKAGE(Jemalloc REQUIRED)
@@ -128,6 +144,51 @@ add_executable(search ${SRC_FILES} src/main/main.cpp)
add_executable(benchmark ${SRC_FILES} src/main/benchmark.cpp)
add_executable(typesense-test ${SRC_FILES} ${TEST_FILES})
add_library(ONNX_SESSION IMPORTED STATIC)
set_target_properties(ONNX_SESSION PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_session.a)
add_library(ONNX_OPT STATIC IMPORTED)
set_target_properties(ONNX_OPT PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_optimizer.a)
add_library(ONNX_PRO STATIC IMPORTED)
set_target_properties(ONNX_PRO PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_providers.a)
add_library(ONNX_UTL STATIC IMPORTED)
set_target_properties(ONNX_UTL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_util.a)
add_library(ONNX_FRM STATIC IMPORTED)
set_target_properties(ONNX_FRM PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_framework.a)
add_library(ONNX_GRP STATIC IMPORTED)
set_target_properties(ONNX_GRP PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_graph.a)
add_library(ONNX_MLS STATIC IMPORTED)
set_target_properties(ONNX_MLS PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_mlas.a)
add_library(ONNX_CMN STATIC IMPORTED)
set_target_properties(ONNX_CMN PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_common.a)
add_library(ONNX_FLT STATIC IMPORTED)
set_target_properties(ONNX_FLT PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_flatbuffers.a)
add_library(ONNX STATIC IMPORTED)
set_target_properties(ONNX PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/onnx-build/libonnx.a)
add_library(ONNX_PRT STATIC IMPORTED)
set_target_properties(ONNX_PRT PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/onnx-build/libonnx_proto.a)
add_library(ONNX_PRTL STATIC IMPORTED)
set_target_properties(ONNX_PRTL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/protobuf-build/libprotobuf-lite.a)
add_library(ONNX_RE STATIC IMPORTED)
set_target_properties(ONNX_RE PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/re2-build/libre2.a)
add_library(ABSL STATIC IMPORTED)
set_target_properties(ABSL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/base/libabsl_base.a)
add_library(ABSL_DEL STATIC IMPORTED)
set_target_properties(ABSL_DEL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a)
add_library(ABSL_RW STATIC IMPORTED)
set_target_properties(ABSL_RW PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a)
add_library(ABSL_HSH STATIC IMPORTED)
set_target_properties(ABSL_HSH PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/hash/libabsl_hash.a)
add_library(ABSL_CTY STATIC IMPORTED)
set_target_properties(ABSL_CTY PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/hash/libabsl_city.a)
add_library(ABSL_LL STATIC IMPORTED)
set_target_properties(ABSL_LL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a)
add_library(NSYNC STATIC IMPORTED)
set_target_properties(NSYNC PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/google_nsync-build/libnsync_cpp.a)
add_library(CPUI STATIC IMPORTED)
set_target_properties(CPUI PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/pytorch_cpuinfo-build/libcpuinfo.a)
add_library(CLOG STATIC IMPORTED)
set_target_properties(CLOG PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a)
target_compile_definitions(
typesense-server PRIVATE
TYPESENSE_VERSION="${TYPESENSE_VERSION}"
@@ -171,9 +232,27 @@ set(CORE_LIBS kakasi h2o-evloop braft brpc iconv ${ICU_ALL_LIBRARIES} ${CURL_LIB
${LevelDB_LIBRARIES} ${ROCKSDB_LIBS}
glog ${GFLAGS_LIBRARIES} ${PROTOBUF_LIBRARIES} ${STACKTRACE_LIBS}
${OPENSSL_LIBRARIES} ${ZLIB_LIBRARIES} ${JEMALLOC_LIBRARIES}
${SYSTEM_LIBS} pthread dl ${STD_LIB})
${SYSTEM_LIBS} pthread dl ${STD_LIB} ONNX_SESSION ONNX_OPT ONNX_PRO ONNX_UTL ONNX_FRM ONNX_GRP ONNX_MLS ONNX_CMN ONNX_FLT ONNX ONNX_PRT ONNX_PRTL ONNX_RE ABSL ABSL_DEL ABSL_RW ABSL_HSH ABSL_CTY ABSL_LL NSYNC CPUI CLOG)
target_link_libraries(typesense-server ${CORE_LIBS})
target_link_libraries(search ${CORE_LIBS})
target_link_libraries(benchmark ${CORE_LIBS})
target_link_libraries(typesense-test ${CORE_LIBS} gtest gtest_main)
add_dependencies(typesense-server onnxruntime)
add_dependencies(typesense-test onnxruntime)
add_dependencies(benchmark onnxruntime)
add_dependencies(search onnxruntime)
# add source files from ${DEP_ROOT_DIR}/${ONNX_EXT_NAME} directory to targets
set(ONNX_EXT_SRC_FILES ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/ustring.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/string_utils_onnx.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/base64.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/tokenizer/bert_tokenizer.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/tokenizer/basic_tokenizer.cc)
set_source_files_properties(${ONNX_EXT_SRC_FILES} PROPERTIES GENERATED TRUE)
target_sources(typesense-server PRIVATE ${ONNX_EXT_SRC_FILES})
target_sources(typesense-test PRIVATE ${ONNX_EXT_SRC_FILES})
target_sources(benchmark PRIVATE ${ONNX_EXT_SRC_FILES})
target_sources(search PRIVATE ${ONNX_EXT_SRC_FILES})
add_dependencies(typesense-server onnxruntime_ext)
add_dependencies(typesense-test onnxruntime_ext)
add_dependencies(benchmark onnxruntime_ext)
add_dependencies(search onnxruntime_ext)

WORKSPACE

@@ -21,7 +21,11 @@ http_archive(
load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies")
rules_foreign_cc_dependencies()
# This sets up some common toolchains for building targets. For more details, please see
# https://bazelbuild.github.io/rules_foreign_cc/0.9.0/flatten.html#rules_foreign_cc_dependencies
rules_foreign_cc_dependencies(
cmake_version="3.25.0",
ninja_version="1.11.1")
# brpc and its dependencies
git_repository(
@@ -34,6 +38,25 @@ git_repository(
remote = "https://github.com/apache/incubator-brpc.git",
)
new_git_repository(
name="onnx_runtime",
branch= "rel-1.14.0",
build_file = "//bazel:onnxruntime.BUILD",
init_submodules= 1,
recursive_init_submodules= 1,
remote= "https://github.com/microsoft/onnxruntime",
patches=["//bazel:onnx.patch"],
)
new_git_repository(
name = "onnx_runtime_extensions",
build_file = "//bazel:onnxruntime_extensions.BUILD",
remote = "https://github.com/microsoft/onnxruntime-extensions",
commit = "81e7799c69044c745239202085eb0a98f102937b",
patches=["//bazel:onnx_ext.patch"],
)
new_git_repository(
name = "com_github_madler_zlib",
build_file = "//bazel:zlib.BUILD",

View File

@@ -43,3 +43,253 @@
feature_configuration = feature_configuration,
action_name = ACTION_NAMES.cpp_link_static_library,
diff --git a/toolchains/built_toolchains.bzl b/toolchains/built_toolchains.bzl
index 5e59e79..ddf63a5 100644
--- toolchains/built_toolchains.bzl
+++ toolchains/built_toolchains.bzl
@@ -28,6 +28,7 @@ _CMAKE_SRCS = {
"3.22.4": [["https://github.com/Kitware/CMake/releases/download/v3.22.4/cmake-3.22.4.tar.gz"], "cmake-3.22.4", "5c55d0b0bc4c191549e3502b8f99a4fe892077611df22b4178cc020626e22a47"],
"3.23.1": [["https://github.com/Kitware/CMake/releases/download/v3.23.1/cmake-3.23.1.tar.gz"], "cmake-3.23.1", "33fd10a8ec687a4d0d5b42473f10459bb92b3ae7def2b745dc10b192760869f3"],
"3.23.2": [["https://github.com/Kitware/CMake/releases/download/v3.23.2/cmake-3.23.2.tar.gz"], "cmake-3.23.2", "f316b40053466f9a416adf981efda41b160ca859e97f6a484b447ea299ff26aa"],
+ "3.25.0": [["https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0.tar.gz"], "cmake-3.25.0", "306463f541555da0942e6f5a0736560f70c487178b9d94a5ae7f34d0538cdd48"],
}
# buildifier: disable=unnamed-macro
@@ -438,6 +439,18 @@ def _ninja_toolchain(version, register_toolchains):
native.register_toolchains(
"@rules_foreign_cc//toolchains:built_ninja_toolchain",
)
+ if version == "1.11.1":
+ maybe(
+ http_archive,
+ name = "ninja_build_src",
+ build_file_content = _ALL_CONTENT,
+ sha256 = "31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea",
+ strip_prefix = "ninja-1.11.1",
+ urls = [
+ "https://github.com/ninja-build/ninja/archive/v1.11.1.tar.gz",
+ ],
+ )
+ return
if version == "1.11.0":
maybe(
http_archive,
diff --git a/toolchains/prebuilt_toolchains.bzl b/toolchains/prebuilt_toolchains.bzl
index dabfb95..d9c38b4 100644
--- toolchains/prebuilt_toolchains.bzl
+++ toolchains/prebuilt_toolchains.bzl
@@ -67,6 +67,115 @@ def prebuilt_toolchains(cmake_version, ninja_version, register_toolchains):
_make_toolchains(register_toolchains)
def _cmake_toolchains(version, register_toolchains):
+ if "3.25.0" == version:
+ maybe(
+ http_archive,
+ name = "cmake-3.25.0-linux-aarch64",
+ urls = [
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-linux-aarch64.tar.gz",
+ ],
+ sha256 = "27da36d6debe9b30f5c498554ae40cd621a55736f5f2ae2618ed95722a59965a",
+ strip_prefix = "cmake-3.25.0-linux-aarch64",
+ build_file_content = _CMAKE_BUILD_FILE.format(
+ bin = "cmake",
+ env = "{}",
+ ),
+ )
+
+ maybe(
+ http_archive,
+ name = "cmake-3.25.0-linux-x86_64",
+ urls = [
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-linux-x86_64.tar.gz",
+ ],
+ sha256 = "ac634d6f0a81d7089adc7be5acff66a6bee3b08615f9a947858ce92a9ef59c8b",
+ strip_prefix = "cmake-3.25.0-linux-x86_64",
+ build_file_content = _CMAKE_BUILD_FILE.format(
+ bin = "cmake",
+ env = "{}",
+ ),
+ )
+
+ maybe(
+ http_archive,
+ name = "cmake-3.25.0-macos-universal",
+ urls = [
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-macos-universal.tar.gz",
+ ],
+ sha256 = "c088e761534a2078cd9d0581d39f02d3f9ed05302e33135b55c6d619b263b4c3",
+ strip_prefix = "cmake-3.25.0-macos-universal/CMake.app/Contents",
+ build_file_content = _CMAKE_BUILD_FILE.format(
+ bin = "cmake",
+ env = "{}",
+ ),
+ )
+
+ maybe(
+ http_archive,
+ name = "cmake-3.25.0-windows-i386",
+ urls = [
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-windows-i386.zip",
+ ],
+ sha256 = "ddd115257a19ff3dd18fc63f32a00ae742f8b62d2e39bc354629903512f99783",
+ strip_prefix = "cmake-3.25.0-windows-i386",
+ build_file_content = _CMAKE_BUILD_FILE.format(
+ bin = "cmake.exe",
+ env = "{}",
+ ),
+ )
+
+ maybe(
+ http_archive,
+ name = "cmake-3.25.0-windows-x86_64",
+ urls = [
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-windows-x86_64.zip",
+ ],
+ sha256 = "b46030c10cab1170355952f9ac59f7e6dabc248070fc53f15dff11d4ed2910f8",
+ strip_prefix = "cmake-3.25.0-windows-x86_64",
+ build_file_content = _CMAKE_BUILD_FILE.format(
+ bin = "cmake.exe",
+ env = "{}",
+ ),
+ )
+
+ # buildifier: leave-alone
+ maybe(
+ prebuilt_toolchains_repository,
+ name = "cmake_3.25.0_toolchains",
+ repos = {
+ "cmake-3.25.0-linux-aarch64": [
+ "@platforms//cpu:aarch64",
+ "@platforms//os:linux",
+ ],
+ "cmake-3.25.0-linux-x86_64": [
+ "@platforms//cpu:x86_64",
+ "@platforms//os:linux",
+ ],
+ "cmake-3.25.0-macos-universal": [
+ "@platforms//os:macos",
+ ],
+ "cmake-3.25.0-windows-i386": [
+ "@platforms//cpu:x86_32",
+ "@platforms//os:windows",
+ ],
+ "cmake-3.25.0-windows-x86_64": [
+ "@platforms//cpu:x86_64",
+ "@platforms//os:windows",
+ ],
+ },
+ tool = "cmake",
+ )
+
+ if register_toolchains:
+ native.register_toolchains(
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-linux-aarch64_toolchain",
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-linux-x86_64_toolchain",
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-macos-universal_toolchain",
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-windows-i386_toolchain",
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-windows-x86_64_toolchain",
+ )
+
+ return
if "3.23.2" == version:
maybe(
http_archive,
@@ -4196,6 +4305,78 @@ def _cmake_toolchains(version, register_toolchains):
fail("Unsupported version: " + str(version))
def _ninja_toolchains(version, register_toolchains):
+ if "1.11.1" == version:
+ maybe(
+ http_archive,
+ name = "ninja_1.11.1_linux",
+ urls = [
+ "https://github.com/ninja-build/ninja/releases/download/v1.11.1/ninja-linux.zip",
+ ],
+ sha256 = "b901ba96e486dce377f9a070ed4ef3f79deb45f4ffe2938f8e7ddc69cfb3df77",
+ strip_prefix = "",
+ build_file_content = _NINJA_BUILD_FILE.format(
+ bin = "ninja",
+ env = "{\"NINJA\": \"$(execpath :ninja_bin)\"}",
+ ),
+ )
+
+ maybe(
+ http_archive,
+ name = "ninja_1.11.1_mac",
+ urls = [
+ "https://github.com/ninja-build/ninja/releases/download/v1.11.1/ninja-mac.zip",
+ ],
+ sha256 = "482ecb23c59ae3d4f158029112de172dd96bb0e97549c4b1ca32d8fad11f873e",
+ strip_prefix = "",
+ build_file_content = _NINJA_BUILD_FILE.format(
+ bin = "ninja",
+ env = "{\"NINJA\": \"$(execpath :ninja_bin)\"}",
+ ),
+ )
+
+ maybe(
+ http_archive,
+ name = "ninja_1.11.1_win",
+ urls = [
+ "https://github.com/ninja-build/ninja/releases/download/v1.11.1/ninja-win.zip",
+ ],
+ sha256 = "524b344a1a9a55005eaf868d991e090ab8ce07fa109f1820d40e74642e289abc",
+ strip_prefix = "",
+ build_file_content = _NINJA_BUILD_FILE.format(
+ bin = "ninja.exe",
+ env = "{\"NINJA\": \"$(execpath :ninja_bin)\"}",
+ ),
+ )
+
+ # buildifier: leave-alone
+ maybe(
+ prebuilt_toolchains_repository,
+ name = "ninja_1.11.1_toolchains",
+ repos = {
+ "ninja_1.11.1_linux": [
+ "@platforms//cpu:x86_64",
+ "@platforms//os:linux",
+ ],
+ "ninja_1.11.1_mac": [
+ "@platforms//cpu:x86_64",
+ "@platforms//os:macos",
+ ],
+ "ninja_1.11.1_win": [
+ "@platforms//cpu:x86_64",
+ "@platforms//os:windows",
+ ],
+ },
+ tool = "ninja",
+ )
+
+ if register_toolchains:
+ native.register_toolchains(
+ "@ninja_1.11.1_toolchains//:ninja_1.11.1_linux_toolchain",
+ "@ninja_1.11.1_toolchains//:ninja_1.11.1_mac_toolchain",
+ "@ninja_1.11.1_toolchains//:ninja_1.11.1_win_toolchain",
+ )
+
+ return
if "1.11.0" == version:
maybe(
http_archive,
diff --git a/toolchains/prebuilt_toolchains.py b/toolchains/prebuilt_toolchains.py
index 5288b27..a193021 100755
--- toolchains/prebuilt_toolchains.py
+++ toolchains/prebuilt_toolchains.py
@@ -10,6 +10,7 @@ CMAKE_SHA256_URL_TEMPLATE = "https://cmake.org/files/v{minor}/cmake-{full}-SHA-2
CMAKE_URL_TEMPLATE = "https://github.com/Kitware/CMake/releases/download/v{full}/{file}"
CMAKE_VERSIONS = [
+ "3.25.0",
"3.23.2",
"3.23.1",
"3.22.4",
@@ -116,6 +117,7 @@ NINJA_TARGETS = {
}
NINJA_VERSIONS = (
+ "1.11.1",
"1.10.2",
"1.10.1",
"1.10.0",

bazel/onnx.patch (new file, 12 lines)

@@ -0,0 +1,12 @@
diff --git a/cmake/external/helper_functions.cmake b/cmake/external/helper_functions.cmake
index 88b46890b7..d090499971 100644
--- cmake/external/helper_functions.cmake
+++ cmake/external/helper_functions.cmake
@@ -117,7 +117,6 @@ macro(onnxruntime_fetchcontent_makeavailable)
${__cmake_contentName}
${__cmake_contentNameLower}
)
- find_package(${__cmake_contentName} ${__cmake_fpArgs})
list(POP_BACK __cmake_fcCurrentNameStack
__cmake_contentNameLower
__cmake_contentName

bazel/onnx_ext.patch (new file, 363 lines)

@@ -0,0 +1,363 @@
diff --git a/operators/string_tensor.cc b/operators/string_tensor.cc
index 3d49e64..84975f6 100644
--- operators/string_tensor.cc
+++ operators/string_tensor.cc
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "string_utils.h"
+#include "string_utils_onnx.h"
#include "string_tensor.h"
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
diff --git a/operators/string_utils.cc b/operators/string_utils_onnx.cc
similarity index 99%
rename from operators/string_utils.cc
rename to operators/string_utils_onnx.cc
index ecb6713..91dbe76 100644
--- operators/string_utils.cc
+++ operators/string_utils_onnx.cc
@@ -2,7 +2,7 @@
#include "farmhash.h"
#endif
-#include "string_utils.h"
+#include "string_utils_onnx.h"
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
std::vector<std::string_view> result;
diff --git a/operators/string_utils.h b/operators/string_utils_onnx.h
similarity index 89%
rename from operators/string_utils.h
rename to operators/string_utils_onnx.h
index 5653fbd..6556666 100644
--- operators/string_utils.h
+++ operators/string_utils_onnx.h
@@ -4,7 +4,6 @@
#include <iostream>
#include <sstream>
#include <vector>
-#include "ocos.h"
template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
@@ -23,11 +22,6 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
ss << "]";
}
-template <>
-inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
- MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
-}
-
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
ss << "[";
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
index 324a774..00eac2b 100644
--- operators/tokenizer/basic_tokenizer.cc
+++ operators/tokenizer/basic_tokenizer.cc
@@ -1,9 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "string_utils.h"
+#include "string_utils_onnx.h"
#include "basic_tokenizer.hpp"
-#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
@@ -81,52 +80,3 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
push_current_token_and_clear();
return result;
}
-
-KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
- bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
- bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
-
- tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
-}
-
-void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
- // Setup inputs
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
- std::vector<std::string> input_data;
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
- OrtTensorDimensions dimensions(ort_, input);
- if (dimensions.size() != 1 && dimensions[0] != 1) {
- ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
- }
-
- OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
- std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
-
- FillTensorDataString(api_, ort_, context, result, output);
-}
-
-void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
- return CreateKernelImpl(api, info);
-};
-
-const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
-
-size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
- return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
-
-size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
- return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp
index 046499e..9fd6f1a 100644
--- operators/tokenizer/basic_tokenizer.hpp
+++ operators/tokenizer/basic_tokenizer.hpp
@@ -3,8 +3,7 @@
#pragma once
-#include "ocos.h"
-#include "string_utils.h"
+#include "string_utils_onnx.h"
#include "ustring.h"
class BasicTokenizer {
@@ -19,19 +18,3 @@ class BasicTokenizer {
bool tokenize_punctuation_;
bool remove_control_chars_;
};
-
-struct KernelBasicTokenizer : BaseKernel {
- KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
- void Compute(OrtKernelContext* context);
- private:
- std::shared_ptr<BasicTokenizer> tokenizer_;
-};
-
-struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
- const char* GetName() const;
- size_t GetInputTypeCount() const;
- ONNXTensorElementDataType GetInputType(size_t index) const;
- size_t GetOutputTypeCount() const;
- ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
index b860ba6..9f43c5e 100644
--- operators/tokenizer/bert_tokenizer.cc
+++ operators/tokenizer/bert_tokenizer.cc
@@ -33,7 +33,8 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) {
- ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
+ std::cout << "[BertTokenizerVocab]: can not find tokens: " + std::string(token);
+ return -1;
}
return it->second;
@@ -276,138 +277,3 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
}
}
-KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
- std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
- bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
- std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
- std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
- std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
- std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
- std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
- std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
- std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
- int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
-
- tokenizer_ = std::make_unique<BertTokenizer>(
- vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
- ustring(sep_token), ustring(pad_token), ustring(cls_token),
- ustring(mask_token), tokenize_chinese_chars, strip_accents,
- ustring(suffix_indicator), max_len, truncation_strategy_name);
-}
-
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
- // Setup inputs
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
- std::vector<std::string> input_data;
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
- if (input_data.size() != 1 && input_data.size() != 2) {
- ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
- }
- std::vector<int64_t> input_ids;
- std::vector<int64_t> token_type_ids;
-
- if (input_data.size() == 1) {
- std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
- std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
- tokenizer_->Truncate(encoded);
- input_ids = tokenizer_->AddSpecialToken(encoded);
- token_type_ids = tokenizer_->GenerateTypeId(encoded);
- } else {
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
- input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
- token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
- }
-
- std::vector<int64_t> attention_mask(input_ids.size(), 1);
-
- std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
-
- SetOutput(context, 0, output_dim, input_ids);
- SetOutput(context, 1, output_dim, token_type_ids);
- SetOutput(context, 2, output_dim, attention_mask);
-}
-
-void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
- return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
-
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
- return 1;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
- return 3;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
-
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
- // Setup inputs
- const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
- std::vector<std::string> input_data;
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
- if (input_data.size() != 2) {
- ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
- }
-
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
- std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
- std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
- std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
-
- const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
- const std::vector<int64_t> inner_dims{1LL};
- for (int32_t i = 0; i < 3; ++i) {
- OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
- OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
- ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
- ort_.ReleaseTensorTypeAndShapeInfo(info);
- }
-
- SetOutput(context, 0, outer_dims, input_ids);
- SetOutput(context, 1, outer_dims, attention_mask);
- SetOutput(context, 2, outer_dims, token_type_ids);
-}
-
-void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
- return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
-
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
- return 1;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
- return 3;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
index 6dfcd84..10565e4 100644
--- operators/tokenizer/bert_tokenizer.hpp
+++ operators/tokenizer/bert_tokenizer.hpp
@@ -3,12 +3,11 @@
#pragma once
+#include <memory>
#include <unordered_map>
#include <vector>
-#include "ocos.h"
#include "ustring.h"
-#include "string_utils.h"
-#include "string_tensor.h"
+#include "string_utils_onnx.h"
#include "basic_tokenizer.hpp"
class BertTokenizerVocab final {
@@ -89,33 +88,4 @@ class BertTokenizer final {
std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
};
-struct KernelBertTokenizer : BaseKernel {
- KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
- void Compute(OrtKernelContext* context);
- protected:
- std::unique_ptr<BertTokenizer> tokenizer_;
-};
-
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
- const char* GetName() const;
- size_t GetInputTypeCount() const;
- ONNXTensorElementDataType GetInputType(size_t index) const;
- size_t GetOutputTypeCount() const;
- ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
-
-struct KernelHfBertTokenizer : KernelBertTokenizer {
- KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
- void Compute(OrtKernelContext* context);
-};
-
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
- const char* GetName() const;
- size_t GetInputTypeCount() const;
- ONNXTensorElementDataType GetInputType(size_t index) const;
- size_t GetOutputTypeCount() const;
- ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
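
The patch above removes the ORT custom-op plumbing (KernelBasicTokenizer, KernelBertTokenizer, the CustomOp* wrappers and the "ocos.h" dependency), leaving BertTokenizer and BasicTokenizer as plain C++ classes. A minimal sketch of driving BertTokenizer directly, mirroring the deleted KernelBertTokenizer::Compute; the helper function itself is illustrative, not code from this PR. The include resolves under both the Bazel "operators" target (strip_include_prefix = "operators") and the CMake src_dir copy.

#include <string>
#include <vector>
#include "tokenizer/bert_tokenizer.hpp"

// Encode one query string into BERT input ids, as the deleted kernel did.
std::vector<int64_t> encode_query(BertTokenizer& tokenizer, const std::string& text) {
    std::vector<ustring> tokens = tokenizer.Tokenize(ustring(text));
    std::vector<int64_t> encoded = tokenizer.Encode(tokens);
    tokenizer.Truncate(encoded);                 // applies the max_length truncation strategy
    return tokenizer.AddSpecialToken(encoded);   // adds the special tokens ([CLS]/[SEP] by default)
}

token_type_ids (via GenerateTypeId) and an all-ones attention mask of the same length complete the model inputs, exactly as in the removed Compute body.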

bazel/onnxruntime.BUILD (new file, 12 lines)

@@ -0,0 +1,12 @@
filegroup(
name = "all_srcs",
srcs = glob(["**"]),
visibility = ["//visibility:public"],
)
cc_library(
name = "hdrs",
hdrs = glob(["include/onnxruntime/**/*.h"]),
strip_include_prefix = "include/onnxruntime",
visibility = ["//visibility:public"],
)

bazel/onnxruntime_extensions.BUILD (new file, 25 lines)

@@ -0,0 +1,25 @@
cc_library(
name = "headers",
hdrs = glob(["includes/**"]),
strip_include_prefix = "includes",
visibility = ["//visibility:public"],
deps = [
"@github_nlohmann_json//:json",
],
)
cc_library(
name = "operators",
hdrs = glob(["operators/base64.h",
"operators/string_utils*",
"operators/ustring.h",
"operators/tokenizer/bert_tokenizer.hpp",
"operators/tokenizer/basic_tokenizer.hpp"]),
srcs = glob(["operators/base64.cc",
"operators/string_utils*",
"operators/ustring.cc",
"operators/tokenizer/bert_tokenizer.cc",
"operators/tokenizer/basic_tokenizer.cc"]),
strip_include_prefix = "operators",
visibility = ["//visibility:public"],
)

View File

@@ -22,7 +22,7 @@ if [[ "$@" == *"--depclean"* ]]; then
fi
cmake -DTYPESENSE_VERSION=$TYPESENSE_VERSION -DCMAKE_BUILD_TYPE=Release -H$PROJECT_DIR -B$PROJECT_DIR/$BUILD_DIR
make typesense-server typesense-test -C $PROJECT_DIR/$BUILD_DIR
cmake --build $PROJECT_DIR/$BUILD_DIR --target typesense-server --target typesense-test
if [[ "$@" == *"--package-binary"* ]]; then
OS_FAMILY=$(echo `uname` | awk '{print tolower($0)}')

cmake/onnx.patch (new file, 12 lines)

@@ -0,0 +1,12 @@
diff --git a/cmake/external/helper_functions.cmake b/cmake/external/helper_functions.cmake
index 88b46890b7..d090499971 100644
--- a/cmake/external/helper_functions.cmake
+++ b/cmake/external/helper_functions.cmake
@@ -117,7 +117,6 @@ macro(onnxruntime_fetchcontent_makeavailable)
${__cmake_contentName}
${__cmake_contentNameLower}
)
- find_package(${__cmake_contentName} ${__cmake_fpArgs})
list(POP_BACK __cmake_fcCurrentNameStack
__cmake_contentNameLower
__cmake_contentName

cmake/onnx_ext.patch (new file, 351 lines)

@@ -0,0 +1,351 @@
diff --git a/operators/string_utils.cc b/operators/string_utils_onnx.cc
similarity index 99%
rename from operators/string_utils.cc
rename to operators/string_utils_onnx.cc
index ecb6713..91dbe76 100644
--- a/operators/string_utils.cc
+++ b/operators/string_utils_onnx.cc
@@ -2,7 +2,7 @@
#include "farmhash.h"
#endif
-#include "string_utils.h"
+#include "string_utils_onnx.h"
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
std::vector<std::string_view> result;
diff --git a/operators/string_utils.h b/operators/string_utils_onnx.h
similarity index 89%
rename from operators/string_utils.h
rename to operators/string_utils_onnx.h
index 5653fbd..6556666 100644
--- a/operators/string_utils.h
+++ b/operators/string_utils_onnx.h
@@ -4,7 +4,6 @@
#include <iostream>
#include <sstream>
#include <vector>
-#include "ocos.h"
template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
@@ -23,11 +22,6 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
ss << "]";
}
-template <>
-inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
- MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
-}
-
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
ss << "[";
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
index 324a774..00eac2b 100644
--- a/operators/tokenizer/basic_tokenizer.cc
+++ b/operators/tokenizer/basic_tokenizer.cc
@@ -1,9 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "string_utils.h"
+#include "string_utils_onnx.h"
#include "basic_tokenizer.hpp"
-#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
@@ -81,52 +80,3 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
push_current_token_and_clear();
return result;
}
-
-KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
- bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
- bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
-
- tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
-}
-
-void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
- // Setup inputs
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
- std::vector<std::string> input_data;
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
- OrtTensorDimensions dimensions(ort_, input);
- if (dimensions.size() != 1 && dimensions[0] != 1) {
- ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
- }
-
- OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
- std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
-
- FillTensorDataString(api_, ort_, context, result, output);
-}
-
-void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
- return CreateKernelImpl(api, info);
-};
-
-const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
-
-size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
- return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
-
-size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
- return 1;
-};
-
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp
index 046499e..9fd6f1a 100644
--- a/operators/tokenizer/basic_tokenizer.hpp
+++ b/operators/tokenizer/basic_tokenizer.hpp
@@ -3,8 +3,7 @@
#pragma once
-#include "ocos.h"
-#include "string_utils.h"
+#include "string_utils_onnx.h"
#include "ustring.h"
class BasicTokenizer {
@@ -19,19 +18,3 @@ class BasicTokenizer {
bool tokenize_punctuation_;
bool remove_control_chars_;
};
-
-struct KernelBasicTokenizer : BaseKernel {
- KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
- void Compute(OrtKernelContext* context);
- private:
- std::shared_ptr<BasicTokenizer> tokenizer_;
-};
-
-struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
- const char* GetName() const;
- size_t GetInputTypeCount() const;
- ONNXTensorElementDataType GetInputType(size_t index) const;
- size_t GetOutputTypeCount() const;
- ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
index b860ba6..9f43c5e 100644
--- a/operators/tokenizer/bert_tokenizer.cc
+++ b/operators/tokenizer/bert_tokenizer.cc
@@ -33,7 +33,8 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) {
- ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
+ std::cout << "[BertTokenizerVocab]: can not find tokens: " + std::string(token);
+ return -1;
}
return it->second;
@@ -276,138 +277,3 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
}
}
-KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
- std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
- bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
- std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
- std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
- std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
- std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
- std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
- std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
- std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
- int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
-
- tokenizer_ = std::make_unique<BertTokenizer>(
- vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
- ustring(sep_token), ustring(pad_token), ustring(cls_token),
- ustring(mask_token), tokenize_chinese_chars, strip_accents,
- ustring(suffix_indicator), max_len, truncation_strategy_name);
-}
-
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
- // Setup inputs
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
- std::vector<std::string> input_data;
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
- if (input_data.size() != 1 && input_data.size() != 2) {
- ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
- }
- std::vector<int64_t> input_ids;
- std::vector<int64_t> token_type_ids;
-
- if (input_data.size() == 1) {
- std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
- std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
- tokenizer_->Truncate(encoded);
- input_ids = tokenizer_->AddSpecialToken(encoded);
- token_type_ids = tokenizer_->GenerateTypeId(encoded);
- } else {
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
- input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
- token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
- }
-
- std::vector<int64_t> attention_mask(input_ids.size(), 1);
-
- std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
-
- SetOutput(context, 0, output_dim, input_ids);
- SetOutput(context, 1, output_dim, token_type_ids);
- SetOutput(context, 2, output_dim, attention_mask);
-}
-
-void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
- return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
-
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
- return 1;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
- return 3;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
-
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
- // Setup inputs
- const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
- std::vector<std::string> input_data;
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
- if (input_data.size() != 2) {
- ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
- }
-
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
- std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
- std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
- std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
-
- const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
- const std::vector<int64_t> inner_dims{1LL};
- for (int32_t i = 0; i < 3; ++i) {
- OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
- OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
- ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
- ort_.ReleaseTensorTypeAndShapeInfo(info);
- }
-
- SetOutput(context, 0, outer_dims, input_ids);
- SetOutput(context, 1, outer_dims, attention_mask);
- SetOutput(context, 2, outer_dims, token_type_ids);
-}
-
-void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
- return CreateKernelImpl(api, info);
-}
-
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
-
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
- return 1;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
- return 3;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
index 6dfcd84..10565e4 100644
--- a/operators/tokenizer/bert_tokenizer.hpp
+++ b/operators/tokenizer/bert_tokenizer.hpp
@@ -3,12 +3,11 @@
#pragma once
+#include <memory>
#include <unordered_map>
#include <vector>
-#include "ocos.h"
#include "ustring.h"
-#include "string_utils.h"
-#include "string_tensor.h"
+#include "string_utils_onnx.h"
#include "basic_tokenizer.hpp"
class BertTokenizerVocab final {
@@ -89,33 +88,4 @@ class BertTokenizer final {
std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
};
-struct KernelBertTokenizer : BaseKernel {
- KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
- void Compute(OrtKernelContext* context);
- protected:
- std::unique_ptr<BertTokenizer> tokenizer_;
-};
-
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
- const char* GetName() const;
- size_t GetInputTypeCount() const;
- ONNXTensorElementDataType GetInputType(size_t index) const;
- size_t GetOutputTypeCount() const;
- ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
-
-struct KernelHfBertTokenizer : KernelBertTokenizer {
- KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
- void Compute(OrtKernelContext* context);
-};
-
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
- const char* GetName() const;
- size_t GetInputTypeCount() const;
- ONNXTensorElementDataType GetInputType(size_t index) const;
- size_t GetOutputTypeCount() const;
- ONNXTensorElementDataType GetOutputType(size_t index) const;
-};

cmake/onnxruntime.cmake (new file, 19 lines)

@@ -0,0 +1,19 @@
project(onnxruntme)
set(ONNX_NAME onnxruntime)
include(ExternalProject)
if(NOT EXISTS ${DEP_ROOT_DIR}/${ONNX_NAME})
file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_NAME})
file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_NAME}-build)
endif()
ExternalProject_Add(
onnxruntime
GIT_REPOSITORY https://github.com/microsoft/onnxruntime
GIT_TAG origin/rel-1.14.0
SOURCE_DIR ${DEP_ROOT_DIR}/${ONNX_NAME}
PATCH_COMMAND cd ${DEP_ROOT_DIR}/${ONNX_NAME} && git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx.patch -R --check
BINARY_DIR ${DEP_ROOT_DIR}/${ONNX_NAME}-build
CONFIGURE_COMMAND ${CMAKE_COMMAND} ${DEP_ROOT_DIR}/${ONNX_NAME}/cmake -B${DEP_ROOT_DIR}/${ONNX_NAME}-build -Donnxruntime_RUN_ONNX_TESTS=OFF -Donnxruntime_GENERATE_TEST_REPORTS=ON -Donnxruntime_USE_MIMALLOC=OFF -Donnxruntime_ENABLE_PYTHON=OFF -Donnxruntime_BUILD_CSHARP=OFF -Donnxruntime_BUILD_JAVA=OFF -Donnxruntime_BUILD_NODEJS=OFF -Donnxruntime_BUILD_OBJC=OFF -Donnxruntime_BUILD_SHARED_LIB=OFF -Donnxruntime_BUILD_APPLE_FRAMEWORK=OFF -Donnxruntime_USE_DNNL=OFF -Donnxruntime_USE_NNAPI_BUILTIN=OFF -Donnxruntime_USE_RKNPU=OFF -Donnxruntime_USE_LLVM=OFF -Donnxruntime_ENABLE_MICROSOFT_INTERNAL=OFF -Donnxruntime_USE_VITISAI=OFF -Donnxruntime_USE_TENSORRT=OFF -Donnxruntime_SKIP_AND_PERFORM_FILTERED_TENSORRT_TESTS=ON -Donnxruntime_USE_TENSORRT_BUILTIN_PARSER=OFF -Donnxruntime_TENSORRT_PLACEHOLDER_BUILDER=OFF -Donnxruntime_USE_TVM=OFF -Donnxruntime_TVM_CUDA_RUNTIME=OFF -Donnxruntime_TVM_USE_HASH=OFF -Donnxruntime_USE_MIGRAPHX=OFF -Donnxruntime_CROSS_COMPILING=OFF -Donnxruntime_DISABLE_CONTRIB_OPS=OFF -Donnxruntime_DISABLE_ML_OPS=OFF -Donnxruntime_DISABLE_RTTI=OFF -Donnxruntime_DISABLE_EXCEPTIONS=OFF -Donnxruntime_MINIMAL_BUILD=OFF -Donnxruntime_EXTENDED_MINIMAL_BUILD=OFF -Donnxruntime_MINIMAL_BUILD_CUSTOM_OPS=OFF -Donnxruntime_REDUCED_OPS_BUILD=OFF -Donnxruntime_ENABLE_LANGUAGE_INTEROP_OPS=OFF -Donnxruntime_USE_DML=OFF -Donnxruntime_USE_WINML=OFF -Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=OFF -Donnxruntime_USE_TELEMETRY=OFF -Donnxruntime_ENABLE_LTO=OFF -Donnxruntime_USE_ACL=OFF -Donnxruntime_USE_ACL_1902=OFF -Donnxruntime_USE_ACL_1905=OFF -Donnxruntime_USE_ACL_1908=OFF -Donnxruntime_USE_ACL_2002=OFF -Donnxruntime_USE_ARMNN=OFF -Donnxruntime_ARMNN_RELU_USE_CPU=ON -Donnxruntime_ARMNN_BN_USE_CPU=ON -Donnxruntime_ENABLE_NVTX_PROFILE=OFF -Donnxruntime_ENABLE_TRAINING=OFF -Donnxruntime_ENABLE_TRAINING_OPS=OFF -Donnxruntime_ENABLE_TRAINING_APIS=OFF -Donnxruntime_ENABLE_CPU_FP16_OPS=OFF -Donnxruntime_USE_NCCL=OFF -Donnxruntime_BUILD_BENCHMARKS=OFF -Donnxruntime_USE_ROCM=OFF -Donnxruntime_GCOV_COVERAGE=OFF -Donnxruntime_USE_MPI=ON -Donnxruntime_ENABLE_MEMORY_PROFILE=OFF -Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=OFF -Donnxruntime_BUILD_WEBASSEMBLY=OFF -Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=ON -Donnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING=ON -Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=OFF -Donnxruntime_ENABLE_EAGER_MODE=OFF -Donnxruntime_ENABLE_LAZY_TENSOR=OFF -Donnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS=OFF -Donnxruntime_ENABLE_CUDA_PROFILING=OFF -Donnxruntime_ENABLE_ROCM_PROFILING=OFF -Donnxruntime_USE_XNNPACK=OFF -Donnxruntime_USE_CANN=OFF -DCMAKE_TLS_VERIFY=ON -DFETCHCONTENT_QUIET=OFF -Donnxruntime_PYBIND_EXPORT_OPSCHEMA=OFF -Donnxruntime_ENABLE_MEMLEAK_CHECKER=OFF -DCMAKE_BUILD_TYPE=Release
BUILD_COMMAND ${CMAKE_COMMAND} --build ${DEP_ROOT_DIR}/${ONNX_NAME}-build --config Release -- -j8
)


@ -0,0 +1,21 @@
project(onnxruntime_ext)
set(ONNX_EXT_NAME onnxruntime_ext)
include(ExternalProject)
if(NOT EXISTS ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
else()
file(REMOVE_RECURSE ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
endif()
ExternalProject_Add(
onnxruntime_ext
GIT_REPOSITORY https://github.com/microsoft/onnxruntime-extensions
GIT_TAG origin/rel-0.6.0
SOURCE_DIR ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}
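    # Besides applying cmake/onnx_ext.patch, the patch step stages the tokenizer
    # sources (base64, string utils, ustring, BERT/basic tokenizer) into
    # operators/src_dir so the main build can compile them directly; configure,
    # build and install are deliberately no-ops.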
PATCH_COMMAND cd ${DEP_ROOT_DIR}/${ONNX_EXT_NAME} && git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx_ext.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx_ext.patch -R --check && cd operators && mkdir -p src_dir && mkdir -p src_dir/tokenizer && cp base64.h base64.cc string_utils_onnx.h string_utils_onnx.cc ustring.h ustring.cc src_dir && cp tokenizer/bert_tokenizer.hpp tokenizer/bert_tokenizer.cc tokenizer/basic_tokenizer.hpp tokenizer/basic_tokenizer.cc src_dir/tokenizer
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)


@ -265,6 +265,8 @@ private:
const spp::sparse_hash_set<std::string>& exclude_fields,
tsl::htrie_set<char>& include_fields_full,
tsl::htrie_set<char>& exclude_fields_full) const;
public:
@ -345,6 +347,8 @@ public:
const DIRTY_VALUES dirty_values,
const std::string& id="");
Option<bool> embed_fields(nlohmann::json& document);
static uint32_t get_seq_id_from_key(const std::string & key);
Option<bool> get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const;
@ -402,7 +406,7 @@ public:
tsl::htrie_set<char>& include_fields_full,
tsl::htrie_set<char>& exclude_fields_full) const;
Option<nlohmann::json> search(const std::string & query, const std::vector<std::string> & search_fields,
Option<nlohmann::json> search(std::string query, const std::vector<std::string> & search_fields,
const std::string & filter_query, const std::vector<std::string> & facet_fields,
const std::vector<sort_by> & sort_fields, const std::vector<uint32_t>& num_typos,
size_t per_page = 10, size_t page = 1,


@ -9,6 +9,7 @@
#include <sparsepp.h>
#include <tsl/htrie_map.h>
#include "json.hpp"
#include "text_embedder_manager.h"
namespace field_types {
// first field value indexed will determine the type
@ -48,6 +49,8 @@ namespace fields {
static const std::string num_dim = "num_dim";
static const std::string vec_dist = "vec_dist";
static const std::string reference = "reference";
static const std::string create_from = "create_from";
static const std::string model_name = "model_name";
}
enum vector_distance_type_t {
@ -73,6 +76,8 @@ struct field {
int nested_array;
size_t num_dim;
std::vector<std::string> create_from;
std::string model_name;
vector_distance_type_t vec_dist;
static constexpr int VAL_UNKNOWN = 2;
@ -83,9 +88,9 @@ struct field {
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "") :
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector<std::string> &create_from = {}, const std::string& model_name = "") :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference) {
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), create_from(create_from), model_name(model_name) {
set_computed_defaults(sort, infix);
}
@ -313,7 +318,12 @@ struct field {
if (!field.reference.empty()) {
field_val[fields::reference] = field.reference;
}
if(!field.create_from.empty()) {
field_val[fields::create_from] = field.create_from;
if(!field.model_name.empty()) {
field_val[fields::model_name] = field.model_name;
}
}
fields_json.push_back(field_val);
if(!field.has_valid_type()) {
@ -409,6 +419,59 @@ struct field {
size_t num_auto_detect_fields = 0;
for(nlohmann::json & field_json: fields_json) {
if(field_json.count(fields::create_from) != 0) {
if(TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
if(!field_json[fields::create_from].is_array()) {
return Option<bool>(400, "Property `" + fields::create_from + "` must be an array.");
}
if(field_json[fields::create_from].empty()) {
return Option<bool>(400, "Property `" + fields::create_from + "` must have at least one element.");
}
for(auto& create_from_field : field_json[fields::create_from]) {
if(!create_from_field.is_string()) {
return Option<bool>(400, "Property `" + fields::create_from + "` must be an array of strings.");
}
}
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
return Option<bool>(400, "Property `" + fields::create_from + "` is only allowed on a float array field.");
}
for(auto& create_from_field : field_json[fields::create_from]) {
bool flag = false;
for(const auto& field : fields_json) {
if(field[fields::name] == create_from_field) {
if(field[fields::type] != field_types::STRING) {
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with array of string fields.");
}
flag = true;
break;
}
}
if(!flag) {
for(const auto& field : the_fields) {
if(field.name == create_from_field) {
if(field.type != field_types::STRING) {
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with array of string fields.");
}
flag = true;
break;
}
}
}
if(!flag) {
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with array of string fields.");
}
}
}
auto op = json_field_to_field(enable_nested_fields,
field_json, the_fields, fallback_field_type, num_auto_detect_fields);
if(!op.ok()) {

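For reference, a collection schema that passes the checks above — the same shape the tests later in this diff use; `model_name` is optional and shown here with the name of the bundled default model:

{
    "name": "objects",
    "fields": [
        {"name": "name", "type": "string"},
        {"name": "embedding", "type": "float[]", "create_from": ["name"], "model_name": "ts-e5-small"}
    ]
}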

@ -24,6 +24,8 @@ private:
static size_t curl_write_async_done(void* context, curl_socket_t item);
static size_t curl_write_download(void *ptr, size_t size, size_t nmemb, FILE *stream);
static CURL* init_curl(const std::string& url, std::string& response);
static CURL* init_curl_async(const std::string& url, deferred_req_res_t* req_res, curl_slist*& chunk);
@ -43,6 +45,8 @@ public:
void init(const std::string & api_key);
static long download_file(const std::string& url, const std::string& file_path);
static long get_response(const std::string& url, std::string& response,
std::map<std::string, std::string>& res_headers, long timeout_ms=4000);

include/text_embedder.h Normal file

@ -0,0 +1,28 @@
#pragma once
#include <core/session/onnxruntime_cxx_api.h>
#include <tokenizer/bert_tokenizer.hpp>
#include <vector>
struct encoded_input_t {
std::vector<int64_t> input_ids;
std::vector<int64_t> token_type_ids;
std::vector<int64_t> attention_mask;
};
class TextEmbedder {
public:
TextEmbedder(const std::string& model_path);
~TextEmbedder();
std::vector<float> Embed(const std::string& text);
static bool is_model_valid(const std::string& model_path, unsigned int& num_dims);
private:
Ort::Session* session_;
Ort::Env env_;
encoded_input_t Encode(const std::string& text);
BertTokenizer* tokenizer_;
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
std::string output_tensor_name;
};


@ -0,0 +1,88 @@
#pragma once
#include <unordered_map>
#include "text_embedder.h"
// singleton class
class TextEmbedderManager {
public:
static TextEmbedderManager& get_instance() {
static TextEmbedderManager instance;
return instance;
}
TextEmbedderManager(TextEmbedderManager&&) = delete;
TextEmbedderManager& operator=(TextEmbedderManager&&) = delete;
TextEmbedderManager(const TextEmbedderManager&) = delete;
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
TextEmbedder* get_text_embedder(const std::string& model_path) {
if (text_embedders.find(model_path) == text_embedders.end()) {
text_embedders[model_path] = new TextEmbedder(model_path);
}
return text_embedders[model_path];
}
void delete_text_embedder(const std::string& model_path) {
if (text_embedders.find(model_path) != text_embedders.end()) {
delete text_embedders[model_path];
text_embedders.erase(model_path);
}
}
void delete_all_text_embedders() {
for (auto& text_embedder : text_embedders) {
delete text_embedder.second;
}
text_embedders.clear();
}
static void set_model_dir(const std::string& dir) {
model_dir = dir;
}
static const std::string& get_model_dir() {
return model_dir;
}
~TextEmbedderManager() {
delete_all_text_embedders();
}
static constexpr const char* DEFAULT_MODEL_URL = "https://huggingface.co/typesense/models/resolve/main/e5-small/model.onnx";
static constexpr const char* DEFAULT_MODEL_NAME = "ts-e5-small";
static constexpr const char* DEFAULT_VOCAB_URL = "https://huggingface.co/typesense/models/resolve/main/e5-small/vocab.txt";
static constexpr const char* DEFAULT_VOCAB_NAME = "vocab.txt";
inline static std::string model_dir = "";
inline static const std::string get_absolute_model_path(const std::string& model_name) {
    // join model_dir and model_name with exactly one '/' before appending ".onnx"
    const bool dir_has_slash = !model_dir.empty() && model_dir.back() == '/';
    const bool name_has_slash = !model_name.empty() && model_name.front() == '/';
    if(dir_has_slash && name_has_slash) {
        return model_dir + model_name.substr(1) + ".onnx";
    } else if(dir_has_slash || name_has_slash) {
        return model_dir + model_name + ".onnx";
    } else {
        return model_dir + "/" + model_name + ".onnx";
    }
}
inline static const std::string get_absolute_vocab_path() {
    if(!model_dir.empty() && model_dir.back() != '/') {
        return model_dir + "/" + TextEmbedderManager::DEFAULT_VOCAB_NAME;
    }
    return model_dir + TextEmbedderManager::DEFAULT_VOCAB_NAME;
}
private:
TextEmbedderManager() = default;
std::unordered_map<std::string, TextEmbedder*> text_embedders;
};

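A minimal usage sketch of the embedding classes above (the model directory and query text are hypothetical; error handling elided):

TextEmbedderManager::set_model_dir("/var/lib/typesense/models");
unsigned int num_dims = 0;
// is_model_valid() verifies the ONNX session exposes the three BERT-style
// inputs and a (-1, -1, hidden_size) output (see src/text_embedder.cpp below)
if(TextEmbedder::is_model_valid(TextEmbedderManager::DEFAULT_MODEL_NAME, num_dims)) {
    TextEmbedder* embedder = TextEmbedderManager::get_instance().get_text_embedder(TextEmbedderManager::DEFAULT_MODEL_NAME);
    std::vector<float> embedding = embedder->Embed("hello world"); // num_dims floats
}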

@ -11,6 +11,7 @@ class Config {
private:
std::string data_dir;
std::string log_dir;
std::string model_dir;
std::string api_key;
@ -108,6 +109,10 @@ public:
this->log_dir = log_dir;
}
void set_model_dir(const std::string & model_dir) {
this->model_dir = model_dir;
}
void set_api_key(const std::string & api_key) {
this->api_key = api_key;
}
@ -171,6 +176,10 @@ public:
return this->log_dir;
}
std::string get_model_dir() const {
return this->model_dir;
}
std::string get_api_key() const {
return this->api_key;
}
@ -565,6 +574,10 @@ public:
auto skip_writes_str = reader.Get("server", "skip-writes", "false");
this->skip_writes = (skip_writes_str == "true");
}
if(reader.Exists("server", "model-dir")) {
this->model_dir = reader.Get("server", "model-dir", "");
}
}
void load_config_cmd_args(cmdline::parser & options) {
@ -580,6 +593,11 @@ public:
this->api_key = options.get<std::string>("api-key");
}
if(options.exist("model-dir")) {
LOG(INFO) << "model-dir found";
this->model_dir = options.get<std::string>("model-dir");
}
// @deprecated
if(options.exist("search-only-api-key")) {
this->search_only_api_key = options.get<std::string>("search-only-api-key");

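For reference, the equivalent config-file form of the new option — a hypothetical stanza; only `model-dir` is taken from the reader.Get() call above, the other keys are illustrative:

[server]
api-key = abcd
data-dir = /tmp/typesense-data
model-dir = /tmp/typesense-models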

@ -16,6 +16,7 @@
#include "logger.h"
#include "thread_local_vars.h"
#include "vector_query_ops.h"
#include "text_embedder_manager.h"
const std::string override_t::MATCH_EXACT = "exact";
const std::string override_t::MATCH_CONTAINS = "contains";
@ -71,6 +72,10 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
const std::string& id) {
try {
document = nlohmann::json::parse(json_str);
auto embed_res = embed_fields(document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
} catch(const std::exception& e) {
LOG(ERROR) << "JSON error: " << e.what();
return Option<doc_seq_id_t>(400, std::string("Bad JSON: ") + e.what());
@ -224,7 +229,15 @@ nlohmann::json Collection::get_summary_json() const {
field_json[fields::sort] = coll_field.sort;
field_json[fields::infix] = coll_field.infix;
field_json[fields::locale] = coll_field.locale;
if(coll_field.create_from.size() > 0) {
field_json[fields::create_from] = coll_field.create_from;
}
if(coll_field.model_name.size() > 0) {
field_json[fields::model_name] = coll_field.model_name;
}
if(coll_field.num_dim > 0) {
field_json[fields::num_dim] = coll_field.num_dim;
}
@ -281,6 +294,7 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
const std::string & json_line = json_lines[i];
Option<doc_seq_id_t> doc_seq_id_op = to_doc(json_line, document, operation, dirty_values, id);
const uint32_t seq_id = doc_seq_id_op.ok() ? doc_seq_id_op.get().seq_id : 0;
index_record record(i, seq_id, document, operation, dirty_values);
@ -963,8 +977,9 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) {
bool exact_key_match = (kv.key().size() == field_name.size());
bool exact_primitive_match = exact_key_match && !kv.value().is_object();
bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().create_from.size() > 0;
if(extract_only_string_fields && !kv.value().is_string()) {
if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) {
if(exact_primitive_match && !is_wildcard) {
// upstream needs to be returned an error
return Option<bool>(400, "Field `" + field_name + "` should be a string or a string array.");
@ -973,7 +988,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
continue;
}
if (exact_primitive_match || is_wildcard ||
if (exact_primitive_match || is_wildcard || text_embedding ||
// field_name prefix must be followed by a "." to indicate an object search
(enable_nested_fields && kv.key().size() > field_name.size() && kv.key()[field_name.size()] == '.')) {
processed_search_fields.push_back(kv.key());
@ -992,7 +1007,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
return Option<bool>(true);
}
Option<nlohmann::json> Collection::search(const std::string & raw_query,
Option<nlohmann::json> Collection::search(std::string raw_query,
const std::vector<std::string>& raw_search_fields,
const std::string & filter_query, const std::vector<std::string>& facet_fields,
const std::vector<sort_by> & sort_fields, const std::vector<uint32_t>& num_typos,
@ -1113,10 +1128,12 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
}
}
// validate search fields
std::vector<std::string> processed_search_fields;
std::vector<uint32_t> query_by_weights;
bool has_embedding_query = false;
for(size_t i = 0; i < raw_search_fields.size(); i++) {
const std::string& field_name = raw_search_fields[i];
if(field_name == "id") {
@ -1132,6 +1149,30 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
}
for(const auto& expanded_search_field: expanded_search_fields) {
auto search_field = search_schema.at(expanded_search_field);
if(search_field.num_dim > 0) {
if(has_embedding_query) {
std::string error = "Only one embedding field is allowed in the query.";
return Option<nlohmann::json>(400, error);
}
if(TextEmbedderManager::model_dir.empty()) {
std::string error = "Text embedding is not enabled. Please set `model-dir` at startup.";
return Option<nlohmann::json>(400, error);
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(search_field.model_name.size() > 0 ? search_field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(raw_query);
vector_query._reset();
vector_query.values = embedding;
vector_query.field_name = field_name;
has_embedding_query = true;
continue;
}
processed_search_fields.push_back(expanded_search_field);
if(!raw_query_by_weights.empty()) {
query_by_weights.push_back(raw_query_by_weights[i]);
@ -1139,14 +1180,19 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
}
}
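// Keep the user's original query for the response payload below; when the
// query targets only an embedding field, the lexical query degenerates to a
// match-all ("*") and the semantics are carried entirely by vector_query.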
std::string real_raw_query = raw_query;
if(has_embedding_query && processed_search_fields.size() == 0) {
raw_query = "*";
}
if(!query_by_weights.empty() && processed_search_fields.size() != query_by_weights.size()) {
std::string error = "Error, query_by_weights.size != query_by.size.";
return Option<nlohmann::json>(400, error);
}
for(const std::string & field_name: processed_search_fields) {
field search_field = search_schema.at(field_name);
if(!search_field.index) {
std::string error = "Field `" + field_name + "` is marked as a non-indexed field in the schema.";
return Option<nlohmann::json>(400, error);
@ -1593,7 +1639,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
}
nlohmann::json result = nlohmann::json::object();
result["found"] = total_found;
if(exclude_fields.count("out_of") == 0) {
@ -1799,7 +1844,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
wrapper_doc["geo_distance_meters"] = geo_distances;
}
if(!vector_query.field_name.empty()) {
if(!vector_query.field_name.empty() && query == "*") {
wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]);
}
@ -1983,7 +2028,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
result["request_params"] = nlohmann::json::object();
result["request_params"]["collection_name"] = name;
result["request_params"]["per_page"] = per_page;
result["request_params"]["q"] = query;
result["request_params"]["q"] = real_raw_query;
//long long int timeMillis = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - begin).count();
//!LOG(INFO) << "Time taken for result calc: " << timeMillis << "us";
@ -4586,3 +4631,33 @@ Option<bool> Collection::populate_include_exclude_fields_lk(const spp::sparse_ha
std::shared_lock lock(mutex);
return populate_include_exclude_fields(include_fields, exclude_fields, include_fields_full, exclude_fields_full);
}
Option<bool> Collection::embed_fields(nlohmann::json& document) {
for(const auto& field : fields) {
if(field.create_from.size() > 0) {
if(TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
std::string text_to_embed;
for(const auto& field_name : field.create_from) {
auto field_it = document.find(field_name);
if(field_it != document.end()) {
if(field_it->is_string()) {
text_to_embed += field_it->get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` is not a string.");
}
} else {
return Option<bool>(400, "Field `" + field_name + "` not found in document.");
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
document[field.name] = embedding;
}
}
return Option<bool>(true);
}

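A sketch of embed_fields in action, assuming the `name`/`embedding` schema used by the tests later in this diff (`coll` is a hypothetical Collection*):

nlohmann::json doc = R"({"name": "apple"})"_json;
auto embed_op = coll->embed_fields(doc); // concatenates the create_from values ("apple ") and embeds them
// on success, doc["embedding"] now holds num_dim floats (384 for the default model)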

@ -58,6 +58,14 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field_obj[fields::reference] = "";
}
if(field_obj.count(fields::create_from) == 0) {
field_obj[fields::create_from] = std::vector<std::string>();
}
if(field_obj.count(fields::model_name) == 0) {
field_obj[fields::model_name] = "";
}
vector_distance_type_t vec_dist_type = vector_distance_type_t::cosine;
if(field_obj.count(fields::vec_dist) != 0) {
@ -70,7 +78,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet],
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference]);
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::create_from],
field_obj[fields::model_name]);
// value of `sort` depends on field type
if(field_obj.count(fields::sort) == 0) {
@ -200,7 +209,6 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
for(size_t coll_index = 0; coll_index < num_collections; coll_index++) {
const auto& collection_meta_json = collection_meta_jsons[coll_index];
nlohmann::json collection_meta = nlohmann::json::parse(collection_meta_json, nullptr, false);
if(collection_meta.is_discarded()) {
LOG(ERROR) << "Error while parsing collection meta, json: " << collection_meta_json;
return Option<bool>(500, "Error while parsing collection meta.");


@ -1,6 +1,7 @@
#include <store.h>
#include "field.h"
#include "magic_enum.hpp"
#include "text_embedder_manager.h"
#include <stack>
#include <collection_manager.h>
@ -662,6 +663,38 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}
}
if(field_json.count(fields::model_name) > 0 && field_json.count(fields::create_from) == 0) {
return Option<bool>(400, "Property `" + fields::model_name + "` can only be used with `" + fields::create_from + "`.");
}
if(field_json.count(fields::create_from) != 0) {
// If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model)
field_json[fields::num_dim] = static_cast<unsigned int>(384);
if(field_json.count(fields::model_name) != 0) {
unsigned int num_dim = 0;
if(!field_json[fields::model_name].is_string()) {
return Option<bool>(400, "Property `" + fields::model_name + "` must be a string.");
}
if(field_json[fields::model_name].get<std::string>().empty()) {
return Option<bool>(400, "Property `" + fields::model_name + "` must be a non-empty string.");
}
if(TextEmbedder::is_model_valid(field_json[fields::model_name].get<std::string>(), num_dim)) {
field_json[fields::num_dim] = num_dim;
} else {
return Option<bool>(400, "Property `" + fields::model_name + "` must be a valid model path.");
}
}
} else {
field_json[fields::create_from] = std::vector<std::string>();
}
if(field_json.count(fields::model_name) == 0) {
field_json[fields::model_name] = "";
}
auto DEFAULT_VEC_DIST_METRIC = magic_enum::enum_name(vector_distance_type_t::cosine);
if(field_json.count(fields::num_dim) == 0) {
@ -698,6 +731,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}
}
if(field_json.count(fields::optional) == 0) {
// dynamic type fields are always optional
bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]);
@ -741,7 +775,8 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
field_json[fields::reference])
field_json[fields::reference], field_json[fields::create_from].get<std::vector<std::string>>(),
field_json[fields::model_name])
);
if (!field_json[fields::reference].get<std::string>().empty()) {


@ -352,3 +352,48 @@ size_t HttpClient::curl_write(char *contents, size_t size, size_t nmemb, std::st
s->append(contents, size*nmemb);
return size*nmemb;
}
size_t HttpClient::curl_write_download(void *ptr, size_t size, size_t nmemb, FILE *stream) {
size_t written = fwrite(ptr, size, nmemb, stream);
return written;
}
long HttpClient::download_file(const std::string& url, const std::string& file_path) {
    CURL *curl = curl_easy_init();
    if(curl == nullptr) {
        return -1;
    }
    FILE *fp = fopen(file_path.c_str(), "wb");
    if(fp == nullptr) {
        LOG(ERROR) << "Unable to open file for writing: " << file_path;
        curl_easy_cleanup(curl);
        return -1;
    }
    curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
    curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT_MS, 4000L);
    curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0L);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_download);
    // follow redirects
    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
    CURLcode res_code = curl_easy_perform(curl);
    if(res_code != CURLE_OK) {
        LOG(ERROR) << "Unable to download file: " << url << " to " << file_path << " - " << curl_easy_strerror(res_code);
        // release the handle and the partially written file before bailing out
        curl_easy_cleanup(curl);
        fclose(fp);
        return -1;
    }
    long http_code = 0;
    curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
    curl_easy_cleanup(curl);
    fclose(fp);
    return http_code;
}

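Usage mirrors the server startup path later in this diff (a minimal sketch; the destination path comes from the embedder manager):

long status = HttpClient::download_file(TextEmbedderManager::DEFAULT_MODEL_URL,
                                        TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
if(status != 200) {
    LOG(ERROR) << "Model download failed with HTTP status " << status;
}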

@ -3010,9 +3010,51 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
sort_order, field_values, geopoint_indices,
curated_ids_sorted, all_result_ids, all_result_ids_len, groups_processed);
if(!vector_query.field_name.empty()) {
// check at least one of sort fields is text match
bool has_text_match = false;
for(auto& sort_field : sort_fields_std) {
if(sort_field.name == sort_field_const::text_match) {
has_text_match = true;
break;
}
}
if(has_text_match) {
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
auto k = std::max<size_t>(vector_query.k, per_page * page);
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, filterFunctor);
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
}
for (const auto& dist_label : dist_labels) {
uint32_t seq_id = dist_label.second;
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
dist_label.first;
auto score = (1.0 - vec_dist_score) * 100000000000.0;
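// e.g. a cosine distance of 0.10 contributes (1.0 - 0.10) * 1e11 = 9e10,
// which is added below onto the document's existing text-match score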
auto found = topster->kv_map.find(seq_id);
if (found != topster->kv_map.end() && found->second->match_score_index >= 0 && found->second->match_score_index <= 2) {
found->second->scores[found->second->match_score_index] += score;
}
}
}
}
/*auto timeMillis0 = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin0).count();
LOG(INFO) << "Time taken for multi-field aggregation: " << timeMillis0 << "ms";*/
}
//LOG(INFO) << "topster size: " << topster->size;

src/text_embedder.cpp Normal file

@ -0,0 +1,156 @@
#include "text_embedder.h"
#include "text_embedder_manager.h"
#include "logger.h"
#include <string>
#include <cstring>
#include <fstream>
#include <sstream>
#include <filesystem>
TextEmbedder::TextEmbedder(const std::string& model_path) {
// create environment
Ort::SessionOptions session_options;
std::string abs_path = TextEmbedderManager::get_absolute_model_path(model_path);
LOG(INFO) << "Loading model from: " << abs_path;
session_ = new Ort::Session(env_, abs_path.c_str(), session_options);
    // load the vocab from the model directory, alongside the .onnx file
    std::ifstream stream(TextEmbedderManager::get_absolute_vocab_path());
    std::stringstream ss;
    ss << stream.rdbuf();
    auto vocab_ = ss.str();
    tokenizer_ = new BertTokenizer(vocab_, true, true, ustring("[UNK]"), ustring("[SEP]"), ustring("[PAD]"),
                                   ustring("[CLS]"), ustring("[MASK]"), true, true, ustring("##"), 512, std::string("longest_first"));
auto output_tensor_count = session_->GetOutputCount();
for (size_t i = 0; i < output_tensor_count; i++) {
auto shape = session_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) {
Ort::AllocatorWithDefaultOptions allocator;
output_tensor_name = std::string(session_->GetOutputNameAllocated(i, allocator).get());
break;
}
}
}
encoded_input_t TextEmbedder::Encode(const std::string& text) {
auto encoded = tokenizer_->Encode(tokenizer_->Tokenize(ustring(text)));
auto input_ids = tokenizer_->AddSpecialToken(encoded);
auto token_type_ids = tokenizer_->GenerateTypeId(encoded);
auto attention_mask = std::vector<int64_t>(input_ids.size(), 1);
return {input_ids, token_type_ids, attention_mask};
}
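// Mean pooling averages the per-token embeddings into a single sentence
// vector, e.g. token vectors {1, 2} and {3, 4} pool to {2, 3}.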
std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs) {
    std::vector<float> pooled_output;
    pooled_output.reserve(inputs[0].size());
    for (size_t i = 0; i < inputs[0].size(); i++) {
        float sum = 0;
        for (size_t j = 0; j < inputs.size(); j++) {
            sum += inputs[j][i];
        }
        pooled_output.push_back(sum / inputs.size());
    }
    return pooled_output;
}
std::vector<float> TextEmbedder::Embed(const std::string& text) {
auto encoded_input = Encode(text);
// create input tensor object from data values
Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
std::vector<Ort::Value> input_tensors;
std::vector<std::vector<int64_t>> input_shapes;
std::vector<const char*> input_node_names = {"input_ids", "attention_mask", "token_type_ids"};
input_shapes.push_back({1, static_cast<int64_t>(encoded_input.input_ids.size())});
input_shapes.push_back({1, static_cast<int64_t>(encoded_input.attention_mask.size())});
input_shapes.push_back({1, static_cast<int64_t>(encoded_input.token_type_ids.size())});
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, encoded_input.input_ids.data(), encoded_input.input_ids.size(), input_shapes[0].data(), input_shapes[0].size()));
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, encoded_input.attention_mask.data(), encoded_input.attention_mask.size(), input_shapes[1].data(), input_shapes[1].size()));
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, encoded_input.token_type_ids.data(), encoded_input.token_type_ids.size(), input_shapes[2].data(), input_shapes[2].size()));
//LOG(INFO) << "Running model";
// create output tensor object
std::vector<const char*> output_node_names = {output_tensor_name.c_str()};
std::vector<int64_t> output_node_dims {1, static_cast<int64_t>(encoded_input.input_ids.size()), 384}; // batch_size x seq_length x hidden_size
auto output_tensor = session_->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_tensors.size(), output_node_names.data(), output_node_names.size());
std::vector<std::vector<float>> output;
float* data = output_tensor[0].GetTensorMutableData<float>();
// print output tensor shape
auto shape = output_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
//LOG(INFO) << "Output tensor size: " << shape[0] << " x " << shape[1] << " x " << shape[2];
for (int i = 0; i < shape[1]; i++) {
std::vector<float> temp;
for (int j = 0; j < shape[2]; j++) {
temp.push_back(data[i * shape[2] + j]);
}
output.push_back(temp);
}
//LOG(INFO) << "Mean pooling";
auto pooled_output = mean_pooling(output);
return pooled_output;
}
TextEmbedder::~TextEmbedder() {
delete tokenizer_;
delete session_;
}
bool TextEmbedder::is_model_valid(const std::string& model_path, unsigned int& num_dims) {
LOG(INFO) << "Loading model: " << model_path;
Ort::SessionOptions session_options;
Ort::Env env;
std::string abs_path = TextEmbedderManager::get_absolute_model_path(model_path);
if(!std::filesystem::exists(abs_path)) {
LOG(ERROR) << "Model file not found: " << abs_path;
return false;
}
Ort::Session session(env, abs_path.c_str(), session_options);
if(session.GetInputCount() != 3) {
LOG(ERROR) << "Invalid model: input count is not 3";
return false;
}
Ort::AllocatorWithDefaultOptions allocator;
auto input_ids_name = session.GetInputNameAllocated(0, allocator);
if (std::strcmp(input_ids_name.get(), "input_ids") != 0) {
LOG(ERROR) << "Invalid model: input_ids tensor not found";
return false;
}
auto attention_mask_name = session.GetInputNameAllocated(1, allocator);
if (std::strcmp(attention_mask_name.get(), "attention_mask") != 0) {
LOG(ERROR) << "Invalid model: attention_mask tensor not found";
return false;
}
auto token_type_ids_name = session.GetInputNameAllocated(2, allocator);
if (std::strcmp(token_type_ids_name.get(), "token_type_ids") != 0) {
LOG(ERROR) << "Invalid model: token_type_ids tensor not found";
return false;
}
auto output_tensor_count = session.GetOutputCount();
bool found_output_tensor = false;
for (size_t i = 0; i < output_tensor_count; i++) {
auto shape = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) {
num_dims = shape[2];
found_output_tensor = true;
break;
}
}
if (!found_output_tensor) {
LOG(ERROR) << "Invalid model: Output tensor not found";
return false;
}
return true;
}


@ -16,6 +16,7 @@
#include "core_api.h"
#include "ratelimit_manager.h"
#include "text_embedder_manager.h"
#include "typesense_server_utils.h"
#include "file_utils.h"
#include "threadpool.h"
@ -66,6 +67,7 @@ void init_cmdline_options(cmdline::parser & options, int argc, char **argv) {
options.set_program_name("./typesense-server");
options.add<std::string>("data-dir", 'd', "Directory where data will be stored.", true);
options.add<std::string>("model-dir", '\0', "Directory where text embedding models will be stored.", false, "");
options.add<std::string>("api-key", 'a', "API key that allows all operations.", true);
options.add<std::string>("search-only-api-key", 's', "[DEPRECATED: use API key management end-point] API key that allows only searches.", false);
@ -452,6 +454,21 @@ int run_server(const Config & config, const std::string & version, void (*master
LOG(INFO) << "Failed to initialize rate limit manager: " << rate_limit_manager_init.error();
}
if(config.get_model_dir().size() > 0) {
LOG(INFO) << "Loading text embedding models from " << config.get_model_dir();
TextEmbedderManager::get_instance().set_model_dir(config.get_model_dir());
LOG(INFO) << "Downloading default model and vocab";
long res = httpClient.download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
if(res != 200) {
LOG(INFO) << "Failed to download default model: " << res;
}
res = httpClient.download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
if(res != 200) {
LOG(INFO) << "Failed to download default vocab: " << res;
}
}
// first we start the peering service
ReplicationState replication_state(server, batch_indexer, &store,


@ -5,6 +5,8 @@
#include <algorithm>
#include <collection_manager.h>
#include "collection.h"
#include "text_embedder_manager.h"
#include "http_client.h"
class CollectionAllFieldsTest : public ::testing::Test {
protected:
@ -27,6 +29,7 @@ protected:
virtual void SetUp() {
setupCollection();
system("mkdir -p models");
}
virtual void TearDown() {
@ -59,6 +62,8 @@ TEST_F(CollectionAllFieldsTest, IndexDocsWithoutSchema) {
while (std::getline(infile, json_line)) {
nlohmann::json document = nlohmann::json::parse(json_line);
Option<nlohmann::json> add_op = coll1->add(document.dump());
LOG(INFO) << "Add op: " << add_op.error();
ASSERT_TRUE(add_op.ok());
}
@ -1586,3 +1591,104 @@ TEST_F(CollectionAllFieldsTest, FieldNameMatchingRegexpShouldNotBeIndexedInNonAu
ASSERT_EQ(1, results["hits"].size());
}
TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) {
TextEmbedderManager::model_dir = "./models";
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["create_from"] = {"name"};
std::vector<field> fields;
std::string fallback_field_type;
auto arr = nlohmann::json::array();
arr.push_back(field_json);
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `create_from` can only be used with array of string fields.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) {
TextEmbedderManager::model_dir = std::string();
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["create_from"] = {"name"};
std::vector<field> fields;
std::string fallback_field_type;
auto arr = nlohmann::json::array();
arr.push_back(field_json);
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Text embedding is not enabled. Please set `model-dir` at startup.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, CreateFromNotArray) {
TextEmbedderManager::model_dir = "./models";
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["create_from"] = "name";
std::vector<field> fields;
std::string fallback_field_type;
auto arr = nlohmann::json::array();
arr.push_back(field_json);
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `create_from` must be an array.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) {
TextEmbedderManager::model_dir = "./models";
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["model_name"] = "model";
std::vector<field> fields;
std::string fallback_field_type;
auto arr = nlohmann::json::array();
arr.push_back(field_json);
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `model_name` can only be used with `create_from`.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) {
TextEmbedderManager::model_dir = "./models/";
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
field embedding = field("embedding", field_types::FLOAT_ARRAY, false);
embedding.create_from.push_back("name");
std::vector<field> fields = {field("name", field_types::STRING, false),
embedding};
auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO);
ASSERT_TRUE(obj_coll_op.ok());
Collection* obj_coll = obj_coll_op.get();
nlohmann::json doc1;
doc1["name"] = "One Two Three";
auto add_res = obj_coll->add(doc1.dump());
ASSERT_TRUE(add_res.ok());
ASSERT_TRUE(add_res.get()["name"].is_string());
ASSERT_TRUE(add_res.get()["embedding"].is_array());
ASSERT_EQ(384, add_res.get()["embedding"].size());
// delete models folder
system("rm -rf ./models");
}


@ -5,6 +5,8 @@
#include <algorithm>
#include <collection_manager.h>
#include "collection.h"
#include "text_embedder_manager.h"
#include "http_client.h"
class CollectionTest : public ::testing::Test {
protected:
@ -62,6 +64,7 @@ protected:
virtual void SetUp() {
setupCollection();
system("mkdir -p models");
}
virtual void TearDown() {
@ -4605,4 +4608,153 @@ TEST_F(CollectionTest, WildcardHighlightFullFields) {
ASSERT_EQ(0, result["hits"][0]["highlight"]["user"]["bio"].count("value"));
ASSERT_EQ(0, result["hits"][0]["highlight"]["user_name"].count("value"));
}
TEST_F(CollectionTest, SemanticSearchTest) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
]
})"_json;
TextEmbedderManager::model_dir = "./models/";
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
nlohmann::json object;
object["name"] = "apple";
auto add_op = coll->add(object.dump());
ASSERT_TRUE(add_op.ok());
ASSERT_EQ("apple", add_op.get()["name"]);
ASSERT_EQ(384, add_op.get()["embedding"].size());
spp::sparse_hash_set<std::string> dummy_include_exclude;
auto search_res_op = coll->search("apple", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
ASSERT_TRUE(search_res_op.ok());
auto search_res = search_res_op.get();
ASSERT_EQ(1, search_res["found"].get<size_t>());
ASSERT_EQ(1, search_res["hits"].size());
ASSERT_EQ("apple", search_res["hits"][0]["document"]["name"].get<std::string>());
ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size());
// delete models folder
system("rm -rf ./models");
}
TEST_F(CollectionTest, InvalidSemanticSearch) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
]
})"_json;
TextEmbedderManager::model_dir = "./models/";
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
auto op = collectionManager.create_collection(schema);
LOG(INFO) << "op.error(): " << op.error();
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
nlohmann::json object;
object["name"] = "apple";
auto add_op = coll->add(object.dump());
ASSERT_TRUE(add_op.ok());
LOG(INFO) << "add_op.get(): " << add_op.get().dump();
ASSERT_EQ("apple", add_op.get()["name"]);
ASSERT_EQ(384, add_op.get()["embedding"].size());
spp::sparse_hash_set<std::string> dummy_include_exclude;
auto search_res_op = coll->search("apple", {"embedding", "embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
ASSERT_FALSE(search_res_op.ok());
// delete models folder
system("rm -rf ./models");
}
TEST_F(CollectionTest, HybridSearch) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
]
})"_json;
TextEmbedderManager::model_dir = "./models/";
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
auto op = collectionManager.create_collection(schema);
LOG(INFO) << collectionManager.get_collection("objects")->get_summary_json();
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
nlohmann::json object;
object["name"] = "apple";
auto add_op = coll->add(object.dump());
LOG(INFO) << "hybrid search";
ASSERT_TRUE(add_op.ok());
ASSERT_EQ("apple", add_op.get()["name"]);
ASSERT_EQ(384, add_op.get()["embedding"].size());
spp::sparse_hash_set<std::string> dummy_include_exclude;
LOG(INFO) << "hybrid search 2";
auto search_res_op = coll->search("apple", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
LOG(INFO) << "hybrid search 3";
ASSERT_TRUE(search_res_op.ok());
auto search_res = search_res_op.get();
ASSERT_EQ(1, search_res["found"].get<size_t>());
ASSERT_EQ(1, search_res["hits"].size());
ASSERT_EQ("apple", search_res["hits"][0]["document"]["name"].get<std::string>());
ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size());
// delete models folder
system("rm -rf ./models");
}
TEST_F(CollectionTest, EmbedFieldsTest) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
]
})"_json;
TextEmbedderManager::model_dir = "./models/";
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
nlohmann::json object = R"({
"name": "apple"
})"_json;
auto embed_op = coll->embed_fields(object);
ASSERT_TRUE(embed_op.ok());
ASSERT_EQ("apple", object["name"]);
ASSERT_EQ(384, object["embedding"].get<std::vector<float>>().size());
// delete models folder
system("rm -rf ./models");
}