mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 06:40:30 +08:00
Merge pull request #885 from ozanarmagan/v0.25
Semantic Search & Hybrid Search
This commit is contained in:
commit
1ad6044bec
3
.gitignore
vendored
3
.gitignore
vendored
@ -12,4 +12,5 @@ cmake-build-release
|
||||
/bazel-*
|
||||
typesense-server-data/
|
||||
.clwb/.bazelproject
|
||||
.vscode/settings.json
|
||||
.vscode/settings.json
|
||||
/onnxruntime-prefix
|
131
BUILD
131
BUILD
@ -37,6 +37,7 @@ cc_library(
|
||||
}),
|
||||
deps = [
|
||||
":headers",
|
||||
":onnxruntime_lib",
|
||||
"@com_github_brpc_braft//:braft",
|
||||
"@com_github_brpc_brpc//:brpc",
|
||||
"@com_github_google_glog//:glog",
|
||||
@ -134,3 +135,133 @@ cc_test(
|
||||
"ROOT_DIR="
|
||||
],
|
||||
)
|
||||
|
||||
load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake")
|
||||
|
||||
__POSTFIX = """
|
||||
mkdir -p $INSTALLDIR/lib
|
||||
mkdir -p $INSTALLDIR/lib/_deps
|
||||
mkdir -p $INSTALLDIR/lib/_deps/onnx-build
|
||||
mkdir -p $INSTALLDIR/lib/_deps/re2-build
|
||||
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build
|
||||
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl
|
||||
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
|
||||
mkdir -p $INSTALLDIR/lib/_deps/protobuf-build
|
||||
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
|
||||
mkdir -p $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
|
||||
mkdir -p $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
|
||||
mkdir -p $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps
|
||||
mkdir -p $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
|
||||
mkdir -p $INSTALLDIR/lib/_deps/google_nsync-build
|
||||
cp $BUILD_TMPDIR/_deps/onnx-build/libonnx.a $INSTALLDIR/lib/_deps/onnx-build
|
||||
cp $BUILD_TMPDIR/_deps/onnx-build/libonnx_proto.a $INSTALLDIR/lib/_deps/onnx-build
|
||||
cp $BUILD_TMPDIR/_deps/re2-build/libre2.a $INSTALLDIR/lib/_deps/re2-build
|
||||
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_base.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
|
||||
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/base
|
||||
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/container
|
||||
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
|
||||
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_city.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
|
||||
cp $BUILD_TMPDIR/_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a $INSTALLDIR/lib/_deps/abseil_cpp-build/absl/hash
|
||||
cp $BUILD_TMPDIR/_deps/google_nsync-build/libnsync_cpp.a $INSTALLDIR/lib/_deps/google_nsync-build
|
||||
cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
|
||||
cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/libcpuinfo.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
|
||||
"""
|
||||
|
||||
cmake(
|
||||
name = "onnxruntime",
|
||||
lib_source = "@onnx_runtime//:all_srcs",
|
||||
cache_entries = {'onnxruntime_RUN_ONNX_TESTS':'OFF',
|
||||
'onnxruntime_GENERATE_TEST_REPORTS':'ON',
|
||||
'onnxruntime_USE_MIMALLOC':'OFF',
|
||||
'onnxruntime_ENABLE_PYTHON':'OFF',
|
||||
'onnxruntime_BUILD_CSHARP':'OFF',
|
||||
'onnxruntime_BUILD_JAVA':'OFF',
|
||||
'onnxruntime_BUILD_NODEJS':'OFF',
|
||||
'onnxruntime_BUILD_OBJC':'OFF',
|
||||
'onnxruntime_BUILD_SHARED_LIB':'OFF',
|
||||
'onnxruntime_BUILD_APPLE_FRAMEWORK':'OFF',
|
||||
'onnxruntime_USE_DNNL':'OFF',
|
||||
'onnxruntime_USE_NNAPI_BUILTIN':'OFF',
|
||||
'onnxruntime_USE_RKNPU':'OFF',
|
||||
'onnxruntime_USE_LLVM':'OFF',
|
||||
'onnxruntime_ENABLE_MICROSOFT_INTERNAL':'OFF',
|
||||
'onnxruntime_USE_VITISAI':'OFF',
|
||||
'onnxruntime_USE_TENSORRT':'OFF',
|
||||
'onnxruntime_SKIP_AND_PERFORM_FILTERED_TENSORRT_TESTS':'ON',
|
||||
'onnxruntime_USE_TENSORRT_BUILTIN_PARSER':'OFF',
|
||||
'onnxruntime_TENSORRT_PLACEHOLDER_BUILDER':'OFF', 'onnxruntime_USE_TVM':'OFF',
|
||||
'onnxruntime_TVM_CUDA_RUNTIME':'OFF', 'onnxruntime_TVM_USE_HASH':'OFF',
|
||||
'onnxruntime_USE_MIGRAPHX':'OFF', 'onnxruntime_CROSS_COMPILING':'OFF',
|
||||
'onnxruntime_DISABLE_CONTRIB_OPS':'OFF', 'onnxruntime_DISABLE_ML_OPS':'OFF',
|
||||
'onnxruntime_DISABLE_RTTI':'OFF', 'onnxruntime_DISABLE_EXCEPTIONS':'OFF',
|
||||
'onnxruntime_MINIMAL_BUILD':'OFF', 'onnxruntime_EXTENDED_MINIMAL_BUILD':'OFF',
|
||||
'onnxruntime_MINIMAL_BUILD_CUSTOM_OPS':'OFF', 'onnxruntime_REDUCED_OPS_BUILD':'OFF',
|
||||
'onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS':'OFF', 'onnxruntime_USE_DML':'OFF',
|
||||
'onnxruntime_USE_WINML':'OFF', 'onnxruntime_BUILD_MS_EXPERIMENTAL_OPS':'OFF',
|
||||
'onnxruntime_USE_TELEMETRY':'OFF', 'onnxruntime_ENABLE_LTO':'OFF',
|
||||
'onnxruntime_USE_ACL':'OFF', 'onnxruntime_USE_ACL_1902':'OFF',
|
||||
'onnxruntime_USE_ACL_1905':'OFF', 'onnxruntime_USE_ACL_1908':'OFF',
|
||||
'onnxruntime_USE_ACL_2002':'OFF', 'onnxruntime_USE_ARMNN':'OFF',
|
||||
'onnxruntime_ARMNN_RELU_USE_CPU':'ON', 'onnxruntime_ARMNN_BN_USE_CPU':'ON',
|
||||
'onnxruntime_ENABLE_NVTX_PROFILE':'OFF', 'onnxruntime_ENABLE_TRAINING':'OFF',
|
||||
'onnxruntime_ENABLE_TRAINING_OPS':'OFF', 'onnxruntime_ENABLE_TRAINING_APIS':'OFF',
|
||||
'onnxruntime_ENABLE_CPU_FP16_OPS':'OFF', 'onnxruntime_USE_NCCL':'OFF',
|
||||
'onnxruntime_BUILD_BENCHMARKS':'OFF', 'onnxruntime_USE_ROCM':'OFF',
|
||||
'Onnxruntime_GCOV_COVERAGE':'OFF', 'onnxruntime_USE_MPI':'ON',
|
||||
'onnxruntime_ENABLE_MEMORY_PROFILE':'OFF',
|
||||
'onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO':'OFF',
|
||||
'onnxruntime_BUILD_WEBASSEMBLY':'OFF', 'onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB':'OFF',
|
||||
'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING':'ON',
|
||||
'onnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING':'OFF',
|
||||
'onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING':'ON',
|
||||
'onnxruntime_ENABLE_WEBASSEMBLY_THREADS':'OFF',
|
||||
'onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO':'OFF',
|
||||
'onnxruntime_ENABLE_WEBASSEMBLY_PROFILING':'OFF',
|
||||
'onnxruntime_ENABLE_EAGER_MODE':'OFF', 'onnxruntime_ENABLE_LAZY_TENSOR':'OFF',
|
||||
'onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS':'OFF', 'onnxruntime_ENABLE_CUDA_PROFILING':'OFF',
|
||||
'onnxruntime_ENABLE_ROCM_PROFILING':'OFF', 'onnxruntime_USE_XNNPACK':'OFF',
|
||||
'onnxruntime_USE_CANN':'OFF', 'CMAKE_TLS_VERIFY':'ON', 'FETCHCONTENT_QUIET':'OFF',
|
||||
'onnxruntime_PYBIND_EXPORT_OPSCHEMA':'OFF', 'onnxruntime_ENABLE_MEMLEAK_CHECKER':'OFF',
|
||||
'CMAKE_BUILD_TYPE':'Release',
|
||||
},
|
||||
working_directory="cmake",
|
||||
build_args= [
|
||||
"--config Release",
|
||||
"-j6"
|
||||
],
|
||||
tags=["requires-network","no-sandbox"],
|
||||
features=["-default_compile_flags","-fno-canonical-system-headers", "-Wno-builtin-macro-redefined"],
|
||||
out_static_libs=[ "libonnxruntime_session.a",
|
||||
"libonnxruntime_optimizer.a",
|
||||
"libonnxruntime_providers.a",
|
||||
"libonnxruntime_util.a",
|
||||
"libonnxruntime_framework.a",
|
||||
"libonnxruntime_graph.a",
|
||||
"libonnxruntime_mlas.a",
|
||||
"libonnxruntime_common.a",
|
||||
"libonnxruntime_flatbuffers.a",
|
||||
"_deps/onnx-build/libonnx.a",
|
||||
"_deps/onnx-build/libonnx_proto.a",
|
||||
"_deps/re2-build/libre2.a",
|
||||
"_deps/abseil_cpp-build/absl/base/libabsl_base.a",
|
||||
"_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a",
|
||||
"_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a",
|
||||
"_deps/abseil_cpp-build/absl/hash/libabsl_hash.a",
|
||||
"_deps/abseil_cpp-build/absl/hash/libabsl_city.a",
|
||||
"_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a",
|
||||
"_deps/google_nsync-build/libnsync_cpp.a",
|
||||
"_deps/pytorch_cpuinfo-build/libcpuinfo.a",
|
||||
"_deps/pytorch_cpuinfo-build/deps/clog/libclog.a",
|
||||
],
|
||||
postfix_script=__POSTFIX
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "onnxruntime_lib",
|
||||
linkopts = select({
|
||||
"@platforms//os:linux": ["-static-libstdc++", "-static-libgcc"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = ["//:onnxruntime", "@onnx_runtime_extensions//:operators"],
|
||||
includes= ["onnxruntime/include/onnxruntime"]
|
||||
)
|
||||
|
@ -1,11 +1,10 @@
|
||||
cmake_minimum_required(VERSION 3.17.5)
|
||||
cmake_minimum_required(VERSION 3.24.0)
|
||||
project(typesense)
|
||||
|
||||
cmake_policy(SET CMP0074 NEW)
|
||||
cmake_policy(SET CMP0003 NEW)
|
||||
|
||||
set(USE_SANTINIZER OFF)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Wall -Wextra -Wno-unused-parameter -Werror=return-type -O2 -g -DNDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -Wextra -Wno-unused-parameter -Werror=return-type -std=c++17 -O0 -g")
|
||||
set(DEP_ROOT_DIR ${CMAKE_SOURCE_DIR}/external-${CMAKE_SYSTEM_NAME})
|
||||
@ -62,6 +61,8 @@ ELSE()
|
||||
set(ENV{CMAKE_FIND_LIBRARY_SUFFIXES} ".a")
|
||||
ENDIF()
|
||||
|
||||
include(cmake/onnxruntime.cmake)
|
||||
include(cmake/onnxruntime_ext.cmake)
|
||||
include(cmake/For.cmake)
|
||||
include(cmake/H2O.cmake)
|
||||
include(cmake/RocksDB.cmake)
|
||||
@ -109,6 +110,9 @@ include_directories(${DEP_ROOT_DIR}/${LRUCACHE_NAME}/include)
|
||||
include_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/build/include)
|
||||
include_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/data)
|
||||
include_directories(${DEP_ROOT_DIR}/${HNSW_NAME})
|
||||
include_directories(${DEP_ROOT_DIR}/${ONNX_NAME}/include/onnxruntime)
|
||||
include_directories(${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir)
|
||||
|
||||
|
||||
link_directories(/usr/local/lib)
|
||||
link_directories(${DEP_ROOT_DIR}/${GTEST_NAME}/googletest/build)
|
||||
@ -119,6 +123,18 @@ link_directories(${DEP_ROOT_DIR}/${ICONV_NAME}/lib/.libs)
|
||||
link_directories(${DEP_ROOT_DIR}/${JEMALLOC_NAME}/lib)
|
||||
link_directories(${DEP_ROOT_DIR}/${S2_NAME}/build)
|
||||
link_directories(${DEP_ROOT_DIR}/${KAKASI_NAME}/build/lib)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/onnx-build)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/re2-build)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl/base)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/protobuf-build)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl/container)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/abseil_cpp-build/absl/hash)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/pytorch_cpuinfo-build)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/pytorch_cpuinfo-build/deps)
|
||||
link_directories(${DEP_ROOT_DIR}/${ONNX_NAME}-build/lib/_deps/pytorch_cpuinfo-build/deps/clog)
|
||||
|
||||
set(JEMALLOC_ROOT_DIR "${DEP_ROOT_DIR}/${JEMALLOC_NAME}")
|
||||
FIND_PACKAGE(Jemalloc REQUIRED)
|
||||
@ -128,6 +144,51 @@ add_executable(search ${SRC_FILES} src/main/main.cpp)
|
||||
add_executable(benchmark ${SRC_FILES} src/main/benchmark.cpp)
|
||||
add_executable(typesense-test ${SRC_FILES} ${TEST_FILES})
|
||||
|
||||
add_library(ONNX_SESSION IMPORTED STATIC)
|
||||
set_target_properties(ONNX_SESSION PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_session.a)
|
||||
add_library(ONNX_OPT STATIC IMPORTED)
|
||||
set_target_properties(ONNX_OPT PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_optimizer.a)
|
||||
add_library(ONNX_PRO STATIC IMPORTED)
|
||||
set_target_properties(ONNX_PRO PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_providers.a)
|
||||
add_library(ONNX_UTL STATIC IMPORTED)
|
||||
set_target_properties(ONNX_UTL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_util.a)
|
||||
add_library(ONNX_FRM STATIC IMPORTED)
|
||||
set_target_properties(ONNX_FRM PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_framework.a)
|
||||
add_library(ONNX_GRP STATIC IMPORTED)
|
||||
set_target_properties(ONNX_GRP PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_graph.a)
|
||||
add_library(ONNX_MLS STATIC IMPORTED)
|
||||
set_target_properties(ONNX_MLS PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_mlas.a)
|
||||
add_library(ONNX_CMN STATIC IMPORTED)
|
||||
set_target_properties(ONNX_CMN PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_common.a)
|
||||
add_library(ONNX_FLT STATIC IMPORTED)
|
||||
set_target_properties(ONNX_FLT PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/libonnxruntime_flatbuffers.a)
|
||||
add_library(ONNX STATIC IMPORTED)
|
||||
set_target_properties(ONNX PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/onnx-build/libonnx.a)
|
||||
add_library(ONNX_PRT STATIC IMPORTED)
|
||||
set_target_properties(ONNX_PRT PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/onnx-build/libonnx_proto.a)
|
||||
add_library(ONNX_PRTL STATIC IMPORTED)
|
||||
set_target_properties(ONNX_PRTL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/protobuf-build/libprotobuf-lite.a)
|
||||
add_library(ONNX_RE STATIC IMPORTED)
|
||||
set_target_properties(ONNX_RE PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/re2-build/libre2.a)
|
||||
add_library(ABSL STATIC IMPORTED)
|
||||
set_target_properties(ABSL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/base/libabsl_base.a)
|
||||
add_library(ABSL_DEL STATIC IMPORTED)
|
||||
set_target_properties(ABSL_DEL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/base/libabsl_throw_delegate.a)
|
||||
add_library(ABSL_RW STATIC IMPORTED)
|
||||
set_target_properties(ABSL_RW PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/container/libabsl_raw_hash_set.a)
|
||||
add_library(ABSL_HSH STATIC IMPORTED)
|
||||
set_target_properties(ABSL_HSH PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/hash/libabsl_hash.a)
|
||||
add_library(ABSL_CTY STATIC IMPORTED)
|
||||
set_target_properties(ABSL_CTY PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/hash/libabsl_city.a)
|
||||
add_library(ABSL_LL STATIC IMPORTED)
|
||||
set_target_properties(ABSL_LL PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/abseil_cpp-build/absl/hash/libabsl_low_level_hash.a)
|
||||
add_library(NSYNC STATIC IMPORTED)
|
||||
set_target_properties(NSYNC PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/google_nsync-build/libnsync_cpp.a)
|
||||
add_library(CPUI STATIC IMPORTED)
|
||||
set_target_properties(CPUI PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/pytorch_cpuinfo-build/libcpuinfo.a)
|
||||
add_library(CLOG STATIC IMPORTED)
|
||||
set_target_properties(CLOG PROPERTIES IMPORTED_LOCATION ${DEP_ROOT_DIR}/${ONNX_NAME}-build/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a)
|
||||
|
||||
target_compile_definitions(
|
||||
typesense-server PRIVATE
|
||||
TYPESENSE_VERSION="${TYPESENSE_VERSION}"
|
||||
@ -171,9 +232,27 @@ set(CORE_LIBS kakasi h2o-evloop braft brpc iconv ${ICU_ALL_LIBRARIES} ${CURL_LIB
|
||||
${LevelDB_LIBRARIES} ${ROCKSDB_LIBS}
|
||||
glog ${GFLAGS_LIBRARIES} ${PROTOBUF_LIBRARIES} ${STACKTRACE_LIBS}
|
||||
${OPENSSL_LIBRARIES} ${ZLIB_LIBRARIES} ${JEMALLOC_LIBRARIES}
|
||||
${SYSTEM_LIBS} pthread dl ${STD_LIB})
|
||||
${SYSTEM_LIBS} pthread dl ${STD_LIB} ONNX_SESSION ONNX_OPT ONNX_PRO ONNX_UTL ONNX_FRM ONNX_GRP ONNX_MLS ONNX_CMN ONNX_FLT ONNX ONNX_PRT ONNX_PRTL ONNX_RE ABSL ABSL_DEL ABSL_RW ABSL_HSH ABSL_CTY ABSL_LL NSYNC CPUI CLOG)
|
||||
|
||||
target_link_libraries(typesense-server ${CORE_LIBS})
|
||||
target_link_libraries(search ${CORE_LIBS})
|
||||
target_link_libraries(benchmark ${CORE_LIBS})
|
||||
target_link_libraries(typesense-test ${CORE_LIBS} gtest gtest_main)
|
||||
|
||||
add_dependencies(typesense-server onnxruntime)
|
||||
add_dependencies(typesense-test onnxruntime)
|
||||
add_dependencies(benchmark onnxruntime)
|
||||
add_dependencies(search onnxruntime)
|
||||
|
||||
# add source files from ${DEP_ROOT_DIR}/${ONNX_EXT_NAME} directory to targets
|
||||
set(ONNX_EXT_SRC_FILES ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/ustring.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/string_utils_onnx.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/base64.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/tokenizer/bert_tokenizer.cc ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}/operators/src_dir/tokenizer/basic_tokenizer.cc)
|
||||
set_source_files_properties(${ONNX_EXT_SRC_FILES} PROPERTIES GENERATED TRUE)
|
||||
target_sources(typesense-server PRIVATE ${ONNX_EXT_SRC_FILES})
|
||||
target_sources(typesense-test PRIVATE ${ONNX_EXT_SRC_FILES})
|
||||
target_sources(benchmark PRIVATE ${ONNX_EXT_SRC_FILES})
|
||||
target_sources(search PRIVATE ${ONNX_EXT_SRC_FILES})
|
||||
|
||||
add_dependencies(typesense-server onnxruntime_ext)
|
||||
add_dependencies(typesense-test onnxruntime_ext)
|
||||
add_dependencies(benchmark onnxruntime_ext)
|
||||
add_dependencies(search onnxruntime_ext)
|
25
WORKSPACE
25
WORKSPACE
@ -21,7 +21,11 @@ http_archive(
|
||||
|
||||
load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies")
|
||||
|
||||
rules_foreign_cc_dependencies()
|
||||
# This sets up some common toolchains for building targets. For more details, please see
|
||||
# https://bazelbuild.github.io/rules_foreign_cc/0.9.0/flatten.html#rules_foreign_cc_dependencies
|
||||
rules_foreign_cc_dependencies(
|
||||
cmake_version="3.25.0",
|
||||
ninja_version="1.11.1")
|
||||
|
||||
# brpc and its dependencies
|
||||
git_repository(
|
||||
@ -34,6 +38,25 @@ git_repository(
|
||||
remote = "https://github.com/apache/incubator-brpc.git",
|
||||
)
|
||||
|
||||
|
||||
new_git_repository(
|
||||
name="onnx_runtime",
|
||||
branch= "rel-1.14.0",
|
||||
build_file = "//bazel:onnxruntime.BUILD",
|
||||
init_submodules= 1,
|
||||
recursive_init_submodules= 1,
|
||||
remote= "https://github.com/microsoft/onnxruntime",
|
||||
patches=["//bazel:onnx.patch"],
|
||||
)
|
||||
|
||||
new_git_repository(
|
||||
name = "onnx_runtime_extensions",
|
||||
build_file = "//bazel:onnxruntime_extensions.BUILD",
|
||||
remote = "https://github.com/microsoft/onnxruntime-extensions",
|
||||
commit = "81e7799c69044c745239202085eb0a98f102937b",
|
||||
patches=["//bazel:onnx_ext.patch"],
|
||||
)
|
||||
|
||||
new_git_repository(
|
||||
name = "com_github_madler_zlib",
|
||||
build_file = "//bazel:zlib.BUILD",
|
||||
|
@ -43,3 +43,253 @@
|
||||
feature_configuration = feature_configuration,
|
||||
action_name = ACTION_NAMES.cpp_link_static_library,
|
||||
|
||||
diff --git a/toolchains/built_toolchains.bzl b/toolchains/built_toolchains.bzl
|
||||
index 5e59e79..ddf63a5 100644
|
||||
--- toolchains/built_toolchains.bzl
|
||||
+++ toolchains/built_toolchains.bzl
|
||||
@@ -28,6 +28,7 @@ _CMAKE_SRCS = {
|
||||
"3.22.4": [["https://github.com/Kitware/CMake/releases/download/v3.22.4/cmake-3.22.4.tar.gz"], "cmake-3.22.4", "5c55d0b0bc4c191549e3502b8f99a4fe892077611df22b4178cc020626e22a47"],
|
||||
"3.23.1": [["https://github.com/Kitware/CMake/releases/download/v3.23.1/cmake-3.23.1.tar.gz"], "cmake-3.23.1", "33fd10a8ec687a4d0d5b42473f10459bb92b3ae7def2b745dc10b192760869f3"],
|
||||
"3.23.2": [["https://github.com/Kitware/CMake/releases/download/v3.23.2/cmake-3.23.2.tar.gz"], "cmake-3.23.2", "f316b40053466f9a416adf981efda41b160ca859e97f6a484b447ea299ff26aa"],
|
||||
+ "3.25.0": [["https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0.tar.gz"], "cmake-3.25.0", "306463f541555da0942e6f5a0736560f70c487178b9d94a5ae7f34d0538cdd48"],
|
||||
}
|
||||
|
||||
# buildifier: disable=unnamed-macro
|
||||
@@ -438,6 +439,18 @@ def _ninja_toolchain(version, register_toolchains):
|
||||
native.register_toolchains(
|
||||
"@rules_foreign_cc//toolchains:built_ninja_toolchain",
|
||||
)
|
||||
+ if version == "1.11.1":
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "ninja_build_src",
|
||||
+ build_file_content = _ALL_CONTENT,
|
||||
+ sha256 = "31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea",
|
||||
+ strip_prefix = "ninja-1.11.1",
|
||||
+ urls = [
|
||||
+ "https://github.com/ninja-build/ninja/archive/v1.11.1.tar.gz",
|
||||
+ ],
|
||||
+ )
|
||||
+ return
|
||||
if version == "1.11.0":
|
||||
maybe(
|
||||
http_archive,
|
||||
diff --git a/toolchains/prebuilt_toolchains.bzl b/toolchains/prebuilt_toolchains.bzl
|
||||
index dabfb95..d9c38b4 100644
|
||||
--- toolchains/prebuilt_toolchains.bzl
|
||||
+++ toolchains/prebuilt_toolchains.bzl
|
||||
@@ -67,6 +67,115 @@ def prebuilt_toolchains(cmake_version, ninja_version, register_toolchains):
|
||||
_make_toolchains(register_toolchains)
|
||||
|
||||
def _cmake_toolchains(version, register_toolchains):
|
||||
+ if "3.25.0" == version:
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "cmake-3.25.0-linux-aarch64",
|
||||
+ urls = [
|
||||
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-linux-aarch64.tar.gz",
|
||||
+ ],
|
||||
+ sha256 = "27da36d6debe9b30f5c498554ae40cd621a55736f5f2ae2618ed95722a59965a",
|
||||
+ strip_prefix = "cmake-3.25.0-linux-aarch64",
|
||||
+ build_file_content = _CMAKE_BUILD_FILE.format(
|
||||
+ bin = "cmake",
|
||||
+ env = "{}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "cmake-3.25.0-linux-x86_64",
|
||||
+ urls = [
|
||||
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-linux-x86_64.tar.gz",
|
||||
+ ],
|
||||
+ sha256 = "ac634d6f0a81d7089adc7be5acff66a6bee3b08615f9a947858ce92a9ef59c8b",
|
||||
+ strip_prefix = "cmake-3.25.0-linux-x86_64",
|
||||
+ build_file_content = _CMAKE_BUILD_FILE.format(
|
||||
+ bin = "cmake",
|
||||
+ env = "{}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "cmake-3.25.0-macos-universal",
|
||||
+ urls = [
|
||||
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-macos-universal.tar.gz",
|
||||
+ ],
|
||||
+ sha256 = "c088e761534a2078cd9d0581d39f02d3f9ed05302e33135b55c6d619b263b4c3",
|
||||
+ strip_prefix = "cmake-3.25.0-macos-universal/CMake.app/Contents",
|
||||
+ build_file_content = _CMAKE_BUILD_FILE.format(
|
||||
+ bin = "cmake",
|
||||
+ env = "{}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "cmake-3.25.0-windows-i386",
|
||||
+ urls = [
|
||||
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-windows-i386.zip",
|
||||
+ ],
|
||||
+ sha256 = "ddd115257a19ff3dd18fc63f32a00ae742f8b62d2e39bc354629903512f99783",
|
||||
+ strip_prefix = "cmake-3.25.0-windows-i386",
|
||||
+ build_file_content = _CMAKE_BUILD_FILE.format(
|
||||
+ bin = "cmake.exe",
|
||||
+ env = "{}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "cmake-3.25.0-windows-x86_64",
|
||||
+ urls = [
|
||||
+ "https://github.com/Kitware/CMake/releases/download/v3.25.0/cmake-3.25.0-windows-x86_64.zip",
|
||||
+ ],
|
||||
+ sha256 = "b46030c10cab1170355952f9ac59f7e6dabc248070fc53f15dff11d4ed2910f8",
|
||||
+ strip_prefix = "cmake-3.25.0-windows-x86_64",
|
||||
+ build_file_content = _CMAKE_BUILD_FILE.format(
|
||||
+ bin = "cmake.exe",
|
||||
+ env = "{}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ # buildifier: leave-alone
|
||||
+ maybe(
|
||||
+ prebuilt_toolchains_repository,
|
||||
+ name = "cmake_3.25.0_toolchains",
|
||||
+ repos = {
|
||||
+ "cmake-3.25.0-linux-aarch64": [
|
||||
+ "@platforms//cpu:aarch64",
|
||||
+ "@platforms//os:linux",
|
||||
+ ],
|
||||
+ "cmake-3.25.0-linux-x86_64": [
|
||||
+ "@platforms//cpu:x86_64",
|
||||
+ "@platforms//os:linux",
|
||||
+ ],
|
||||
+ "cmake-3.25.0-macos-universal": [
|
||||
+ "@platforms//os:macos",
|
||||
+ ],
|
||||
+ "cmake-3.25.0-windows-i386": [
|
||||
+ "@platforms//cpu:x86_32",
|
||||
+ "@platforms//os:windows",
|
||||
+ ],
|
||||
+ "cmake-3.25.0-windows-x86_64": [
|
||||
+ "@platforms//cpu:x86_64",
|
||||
+ "@platforms//os:windows",
|
||||
+ ],
|
||||
+ },
|
||||
+ tool = "cmake",
|
||||
+ )
|
||||
+
|
||||
+ if register_toolchains:
|
||||
+ native.register_toolchains(
|
||||
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-linux-aarch64_toolchain",
|
||||
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-linux-x86_64_toolchain",
|
||||
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-macos-universal_toolchain",
|
||||
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-windows-i386_toolchain",
|
||||
+ "@cmake_3.25.0_toolchains//:cmake-3.25.0-windows-x86_64_toolchain",
|
||||
+ )
|
||||
+
|
||||
+ return
|
||||
if "3.23.2" == version:
|
||||
maybe(
|
||||
http_archive,
|
||||
@@ -4196,6 +4305,78 @@ def _cmake_toolchains(version, register_toolchains):
|
||||
fail("Unsupported version: " + str(version))
|
||||
|
||||
def _ninja_toolchains(version, register_toolchains):
|
||||
+ if "1.11.1" == version:
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "ninja_1.11.1_linux",
|
||||
+ urls = [
|
||||
+ "https://github.com/ninja-build/ninja/releases/download/v1.11.1/ninja-linux.zip",
|
||||
+ ],
|
||||
+ sha256 = "b901ba96e486dce377f9a070ed4ef3f79deb45f4ffe2938f8e7ddc69cfb3df77",
|
||||
+ strip_prefix = "",
|
||||
+ build_file_content = _NINJA_BUILD_FILE.format(
|
||||
+ bin = "ninja",
|
||||
+ env = "{\"NINJA\": \"$(execpath :ninja_bin)\"}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "ninja_1.11.1_mac",
|
||||
+ urls = [
|
||||
+ "https://github.com/ninja-build/ninja/releases/download/v1.11.1/ninja-mac.zip",
|
||||
+ ],
|
||||
+ sha256 = "482ecb23c59ae3d4f158029112de172dd96bb0e97549c4b1ca32d8fad11f873e",
|
||||
+ strip_prefix = "",
|
||||
+ build_file_content = _NINJA_BUILD_FILE.format(
|
||||
+ bin = "ninja",
|
||||
+ env = "{\"NINJA\": \"$(execpath :ninja_bin)\"}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ maybe(
|
||||
+ http_archive,
|
||||
+ name = "ninja_1.11.1_win",
|
||||
+ urls = [
|
||||
+ "https://github.com/ninja-build/ninja/releases/download/v1.11.1/ninja-win.zip",
|
||||
+ ],
|
||||
+ sha256 = "524b344a1a9a55005eaf868d991e090ab8ce07fa109f1820d40e74642e289abc",
|
||||
+ strip_prefix = "",
|
||||
+ build_file_content = _NINJA_BUILD_FILE.format(
|
||||
+ bin = "ninja.exe",
|
||||
+ env = "{\"NINJA\": \"$(execpath :ninja_bin)\"}",
|
||||
+ ),
|
||||
+ )
|
||||
+
|
||||
+ # buildifier: leave-alone
|
||||
+ maybe(
|
||||
+ prebuilt_toolchains_repository,
|
||||
+ name = "ninja_1.11.1_toolchains",
|
||||
+ repos = {
|
||||
+ "ninja_1.11.1_linux": [
|
||||
+ "@platforms//cpu:x86_64",
|
||||
+ "@platforms//os:linux",
|
||||
+ ],
|
||||
+ "ninja_1.11.1_mac": [
|
||||
+ "@platforms//cpu:x86_64",
|
||||
+ "@platforms//os:macos",
|
||||
+ ],
|
||||
+ "ninja_1.11.1_win": [
|
||||
+ "@platforms//cpu:x86_64",
|
||||
+ "@platforms//os:windows",
|
||||
+ ],
|
||||
+ },
|
||||
+ tool = "ninja",
|
||||
+ )
|
||||
+
|
||||
+ if register_toolchains:
|
||||
+ native.register_toolchains(
|
||||
+ "@ninja_1.11.1_toolchains//:ninja_1.11.1_linux_toolchain",
|
||||
+ "@ninja_1.11.1_toolchains//:ninja_1.11.1_mac_toolchain",
|
||||
+ "@ninja_1.11.1_toolchains//:ninja_1.11.1_win_toolchain",
|
||||
+ )
|
||||
+
|
||||
+ return
|
||||
if "1.11.0" == version:
|
||||
maybe(
|
||||
http_archive,
|
||||
diff --git a/toolchains/prebuilt_toolchains.py b/toolchains/prebuilt_toolchains.py
|
||||
index 5288b27..a193021 100755
|
||||
--- toolchains/prebuilt_toolchains.py
|
||||
+++ toolchains/prebuilt_toolchains.py
|
||||
@@ -10,6 +10,7 @@ CMAKE_SHA256_URL_TEMPLATE = "https://cmake.org/files/v{minor}/cmake-{full}-SHA-2
|
||||
CMAKE_URL_TEMPLATE = "https://github.com/Kitware/CMake/releases/download/v{full}/{file}"
|
||||
|
||||
CMAKE_VERSIONS = [
|
||||
+ "3.25.0",
|
||||
"3.23.2",
|
||||
"3.23.1",
|
||||
"3.22.4",
|
||||
@@ -116,6 +117,7 @@ NINJA_TARGETS = {
|
||||
}
|
||||
|
||||
NINJA_VERSIONS = (
|
||||
+ "1.11.1",
|
||||
"1.10.2",
|
||||
"1.10.1",
|
||||
"1.10.0",
|
||||
|
12
bazel/onnx.patch
Normal file
12
bazel/onnx.patch
Normal file
@ -0,0 +1,12 @@
|
||||
diff --git a/cmake/external/helper_functions.cmake b/cmake/external/helper_functions.cmake
|
||||
index 88b46890b7..d090499971 100644
|
||||
--- cmake/external/helper_functions.cmake
|
||||
+++ cmake/external/helper_functions.cmake
|
||||
@@ -117,7 +117,6 @@ macro(onnxruntime_fetchcontent_makeavailable)
|
||||
${__cmake_contentName}
|
||||
${__cmake_contentNameLower}
|
||||
)
|
||||
- find_package(${__cmake_contentName} ${__cmake_fpArgs})
|
||||
list(POP_BACK __cmake_fcCurrentNameStack
|
||||
__cmake_contentNameLower
|
||||
__cmake_contentName
|
363
bazel/onnx_ext.patch
Normal file
363
bazel/onnx_ext.patch
Normal file
@ -0,0 +1,363 @@
|
||||
diff --git a/operators/string_tensor.cc b/operators/string_tensor.cc
|
||||
index 3d49e64..84975f6 100644
|
||||
--- operators/string_tensor.cc
|
||||
+++ operators/string_tensor.cc
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
diff --git a/operators/string_utils.cc b/operators/string_utils_onnx.cc
|
||||
similarity index 99%
|
||||
rename from operators/string_utils.cc
|
||||
rename to operators/string_utils_onnx.cc
|
||||
index ecb6713..91dbe76 100644
|
||||
--- operators/string_utils.cc
|
||||
+++ operators/string_utils_onnx.cc
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "farmhash.h"
|
||||
#endif
|
||||
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
|
||||
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
|
||||
std::vector<std::string_view> result;
|
||||
diff --git a/operators/string_utils.h b/operators/string_utils_onnx.h
|
||||
similarity index 89%
|
||||
rename from operators/string_utils.h
|
||||
rename to operators/string_utils_onnx.h
|
||||
index 5653fbd..6556666 100644
|
||||
--- operators/string_utils.h
|
||||
+++ operators/string_utils_onnx.h
|
||||
@@ -4,7 +4,6 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
-#include "ocos.h"
|
||||
|
||||
template <typename T>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
||||
@@ -23,11 +22,6 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
-template <>
|
||||
-inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
|
||||
- MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
|
||||
-}
|
||||
-
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
|
||||
ss << "[";
|
||||
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
|
||||
index 324a774..00eac2b 100644
|
||||
--- operators/tokenizer/basic_tokenizer.cc
|
||||
+++ operators/tokenizer/basic_tokenizer.cc
|
||||
@@ -1,9 +1,8 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
#include "basic_tokenizer.hpp"
|
||||
-#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
@@ -81,52 +80,3 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
||||
push_current_token_and_clear();
|
||||
return result;
|
||||
}
|
||||
-
|
||||
-KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
- bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
|
||||
- bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
|
||||
-
|
||||
- tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
|
||||
-}
|
||||
-
|
||||
-void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
|
||||
- // Setup inputs
|
||||
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
- std::vector<std::string> input_data;
|
||||
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
-
|
||||
- OrtTensorDimensions dimensions(ort_, input);
|
||||
- if (dimensions.size() != 1 && dimensions[0] != 1) {
|
||||
- ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
|
||||
- }
|
||||
-
|
||||
- OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
- std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
-
|
||||
- FillTensorDataString(api_, ort_, context, result, output);
|
||||
-}
|
||||
-
|
||||
-void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
- return CreateKernelImpl(api, info);
|
||||
-};
|
||||
-
|
||||
-const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
|
||||
-
|
||||
-size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
|
||||
- return 1;
|
||||
-};
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-};
|
||||
-
|
||||
-size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
|
||||
- return 1;
|
||||
-};
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-};
|
||||
diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp
|
||||
index 046499e..9fd6f1a 100644
|
||||
--- operators/tokenizer/basic_tokenizer.hpp
|
||||
+++ operators/tokenizer/basic_tokenizer.hpp
|
||||
@@ -3,8 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
-#include "ocos.h"
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
#include "ustring.h"
|
||||
|
||||
class BasicTokenizer {
|
||||
@@ -19,19 +18,3 @@ class BasicTokenizer {
|
||||
bool tokenize_punctuation_;
|
||||
bool remove_control_chars_;
|
||||
};
|
||||
-
|
||||
-struct KernelBasicTokenizer : BaseKernel {
|
||||
- KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
- void Compute(OrtKernelContext* context);
|
||||
- private:
|
||||
- std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||
-};
|
||||
-
|
||||
-struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
||||
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
- const char* GetName() const;
|
||||
- size_t GetInputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
- size_t GetOutputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
-};
|
||||
diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
|
||||
index b860ba6..9f43c5e 100644
|
||||
--- operators/tokenizer/bert_tokenizer.cc
|
||||
+++ operators/tokenizer/bert_tokenizer.cc
|
||||
@@ -33,7 +33,8 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
|
||||
|
||||
auto it = vocab_.find(utf8_token);
|
||||
if (it == vocab_.end()) {
|
||||
- ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
|
||||
+ std::cout << "[BertTokenizerVocab]: can not find tokens: " + std::string(token);
|
||||
+ return -1;
|
||||
}
|
||||
|
||||
return it->second;
|
||||
@@ -276,138 +277,3 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
|
||||
}
|
||||
}
|
||||
|
||||
-KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
- std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
- bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
|
||||
- std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
||||
- std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
|
||||
- std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
|
||||
- std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
|
||||
- std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
|
||||
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
- std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
|
||||
- std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
|
||||
- int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
|
||||
-
|
||||
- tokenizer_ = std::make_unique<BertTokenizer>(
|
||||
- vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
|
||||
- ustring(sep_token), ustring(pad_token), ustring(cls_token),
|
||||
- ustring(mask_token), tokenize_chinese_chars, strip_accents,
|
||||
- ustring(suffix_indicator), max_len, truncation_strategy_name);
|
||||
-}
|
||||
-
|
||||
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
- // Setup inputs
|
||||
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
- std::vector<std::string> input_data;
|
||||
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
-
|
||||
- if (input_data.size() != 1 && input_data.size() != 2) {
|
||||
- ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
|
||||
- }
|
||||
- std::vector<int64_t> input_ids;
|
||||
- std::vector<int64_t> token_type_ids;
|
||||
-
|
||||
- if (input_data.size() == 1) {
|
||||
- std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
- std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
|
||||
- tokenizer_->Truncate(encoded);
|
||||
- input_ids = tokenizer_->AddSpecialToken(encoded);
|
||||
- token_type_ids = tokenizer_->GenerateTypeId(encoded);
|
||||
- } else {
|
||||
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
|
||||
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
|
||||
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
|
||||
- input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
|
||||
- token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
|
||||
- }
|
||||
-
|
||||
- std::vector<int64_t> attention_mask(input_ids.size(), 1);
|
||||
-
|
||||
- std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
|
||||
-
|
||||
- SetOutput(context, 0, output_dim, input_ids);
|
||||
- SetOutput(context, 1, output_dim, token_type_ids);
|
||||
- SetOutput(context, 2, output_dim, attention_mask);
|
||||
-}
|
||||
-
|
||||
-void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
- return CreateKernelImpl(api, info);
|
||||
-}
|
||||
-
|
||||
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
|
||||
-
|
||||
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
|
||||
- return 1;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-}
|
||||
-
|
||||
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
|
||||
- return 3;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
-}
|
||||
-
|
||||
-KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
||||
-
|
||||
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
- // Setup inputs
|
||||
- const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
|
||||
- std::vector<std::string> input_data;
|
||||
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
-
|
||||
- if (input_data.size() != 2) {
|
||||
- ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
|
||||
- }
|
||||
-
|
||||
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
|
||||
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
|
||||
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
|
||||
- std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
|
||||
- std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
|
||||
- std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
|
||||
-
|
||||
- const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
|
||||
- const std::vector<int64_t> inner_dims{1LL};
|
||||
- for (int32_t i = 0; i < 3; ++i) {
|
||||
- OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
|
||||
- OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
|
||||
- ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
|
||||
- ort_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
- }
|
||||
-
|
||||
- SetOutput(context, 0, outer_dims, input_ids);
|
||||
- SetOutput(context, 1, outer_dims, attention_mask);
|
||||
- SetOutput(context, 2, outer_dims, token_type_ids);
|
||||
-}
|
||||
-
|
||||
-void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
- return CreateKernelImpl(api, info);
|
||||
-}
|
||||
-
|
||||
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
|
||||
-
|
||||
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
|
||||
- return 1;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-}
|
||||
-
|
||||
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
|
||||
- return 3;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
-}
|
||||
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
|
||||
index 6dfcd84..10565e4 100644
|
||||
--- operators/tokenizer/bert_tokenizer.hpp
|
||||
+++ operators/tokenizer/bert_tokenizer.hpp
|
||||
@@ -3,12 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
+#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
-#include "ocos.h"
|
||||
#include "ustring.h"
|
||||
-#include "string_utils.h"
|
||||
-#include "string_tensor.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
#include "basic_tokenizer.hpp"
|
||||
|
||||
class BertTokenizerVocab final {
|
||||
@@ -89,33 +88,4 @@ class BertTokenizer final {
|
||||
std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
|
||||
};
|
||||
|
||||
-struct KernelBertTokenizer : BaseKernel {
|
||||
- KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
- void Compute(OrtKernelContext* context);
|
||||
|
||||
- protected:
|
||||
- std::unique_ptr<BertTokenizer> tokenizer_;
|
||||
-};
|
||||
-
|
||||
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
||||
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
- const char* GetName() const;
|
||||
- size_t GetInputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
- size_t GetOutputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
-};
|
||||
-
|
||||
-struct KernelHfBertTokenizer : KernelBertTokenizer {
|
||||
- KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
- void Compute(OrtKernelContext* context);
|
||||
-};
|
||||
-
|
||||
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
||||
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
- const char* GetName() const;
|
||||
- size_t GetInputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
- size_t GetOutputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
-};
|
12
bazel/onnxruntime.BUILD
Normal file
12
bazel/onnxruntime.BUILD
Normal file
@ -0,0 +1,12 @@
|
||||
filegroup(
|
||||
name = "all_srcs",
|
||||
srcs = glob(["**"]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hdrs",
|
||||
hdrs = glob(["include/onnxruntime/**/*.h"]),
|
||||
strip_include_prefix = "include/onnxruntime",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
25
bazel/onnxruntime_extensions.BUILD
Normal file
25
bazel/onnxruntime_extensions.BUILD
Normal file
@ -0,0 +1,25 @@
|
||||
cc_library(
|
||||
name = "headers",
|
||||
hdrs = glob(["includes/**"]),
|
||||
strip_include_prefix = "includes",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@github_nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "operators",
|
||||
hdrs = glob(["operators/base64.h",
|
||||
"operators/string_utils*",
|
||||
"operators/ustring.h",
|
||||
"operators/tokenizer/bert_tokenizer.hpp",
|
||||
"operators/tokenizer/basic_tokenizer.hpp"]),
|
||||
srcs = glob(["operators/base64.cc",
|
||||
"operators/string_utils*",
|
||||
"operators/ustring.cc",
|
||||
"operators/tokenizer/bert_tokenizer.cc",
|
||||
"operators/tokenizer/basic_tokenizer.cc"]),
|
||||
strip_include_prefix = "operators",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
2
build.sh
2
build.sh
@ -22,7 +22,7 @@ if [[ "$@" == *"--depclean"* ]]; then
|
||||
fi
|
||||
|
||||
cmake -DTYPESENSE_VERSION=$TYPESENSE_VERSION -DCMAKE_BUILD_TYPE=Release -H$PROJECT_DIR -B$PROJECT_DIR/$BUILD_DIR
|
||||
make typesense-server typesense-test -C $PROJECT_DIR/$BUILD_DIR
|
||||
cmake --build $PROJECT_DIR/$BUILD_DIR --target typesense-server --target typesense-test
|
||||
|
||||
if [[ "$@" == *"--package-binary"* ]]; then
|
||||
OS_FAMILY=$(echo `uname` | awk '{print tolower($0)}')
|
||||
|
12
cmake/onnx.patch
Normal file
12
cmake/onnx.patch
Normal file
@ -0,0 +1,12 @@
|
||||
diff --git a/cmake/external/helper_functions.cmake b/cmake/external/helper_functions.cmake
|
||||
index 88b46890b7..d090499971 100644
|
||||
--- a/cmake/external/helper_functions.cmake
|
||||
+++ b/cmake/external/helper_functions.cmake
|
||||
@@ -117,7 +117,6 @@ macro(onnxruntime_fetchcontent_makeavailable)
|
||||
${__cmake_contentName}
|
||||
${__cmake_contentNameLower}
|
||||
)
|
||||
- find_package(${__cmake_contentName} ${__cmake_fpArgs})
|
||||
list(POP_BACK __cmake_fcCurrentNameStack
|
||||
__cmake_contentNameLower
|
||||
__cmake_contentName
|
351
cmake/onnx_ext.patch
Normal file
351
cmake/onnx_ext.patch
Normal file
@ -0,0 +1,351 @@
|
||||
diff --git a/operators/string_utils.cc b/operators/string_utils_onnx.cc
|
||||
similarity index 99%
|
||||
rename from operators/string_utils.cc
|
||||
rename to operators/string_utils_onnx.cc
|
||||
index ecb6713..91dbe76 100644
|
||||
--- a/operators/string_utils.cc
|
||||
+++ b/operators/string_utils_onnx.cc
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "farmhash.h"
|
||||
#endif
|
||||
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
|
||||
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
|
||||
std::vector<std::string_view> result;
|
||||
diff --git a/operators/string_utils.h b/operators/string_utils_onnx.h
|
||||
similarity index 89%
|
||||
rename from operators/string_utils.h
|
||||
rename to operators/string_utils_onnx.h
|
||||
index 5653fbd..6556666 100644
|
||||
--- a/operators/string_utils.h
|
||||
+++ b/operators/string_utils_onnx.h
|
||||
@@ -4,7 +4,6 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
-#include "ocos.h"
|
||||
|
||||
template <typename T>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
||||
@@ -23,11 +22,6 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
-template <>
|
||||
-inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
|
||||
- MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
|
||||
-}
|
||||
-
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
|
||||
ss << "[";
|
||||
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
|
||||
index 324a774..00eac2b 100644
|
||||
--- a/operators/tokenizer/basic_tokenizer.cc
|
||||
+++ b/operators/tokenizer/basic_tokenizer.cc
|
||||
@@ -1,9 +1,8 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
#include "basic_tokenizer.hpp"
|
||||
-#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
@@ -81,52 +80,3 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
||||
push_current_token_and_clear();
|
||||
return result;
|
||||
}
|
||||
-
|
||||
-KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
- bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
|
||||
- bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
|
||||
-
|
||||
- tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
|
||||
-}
|
||||
-
|
||||
-void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
|
||||
- // Setup inputs
|
||||
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
- std::vector<std::string> input_data;
|
||||
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
-
|
||||
- OrtTensorDimensions dimensions(ort_, input);
|
||||
- if (dimensions.size() != 1 && dimensions[0] != 1) {
|
||||
- ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
|
||||
- }
|
||||
-
|
||||
- OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
- std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
-
|
||||
- FillTensorDataString(api_, ort_, context, result, output);
|
||||
-}
|
||||
-
|
||||
-void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
- return CreateKernelImpl(api, info);
|
||||
-};
|
||||
-
|
||||
-const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
|
||||
-
|
||||
-size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
|
||||
- return 1;
|
||||
-};
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-};
|
||||
-
|
||||
-size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
|
||||
- return 1;
|
||||
-};
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-};
|
||||
diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp
|
||||
index 046499e..9fd6f1a 100644
|
||||
--- a/operators/tokenizer/basic_tokenizer.hpp
|
||||
+++ b/operators/tokenizer/basic_tokenizer.hpp
|
||||
@@ -3,8 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
-#include "ocos.h"
|
||||
-#include "string_utils.h"
|
||||
+#include "string_utils_onnx.h"
|
||||
#include "ustring.h"
|
||||
|
||||
class BasicTokenizer {
|
||||
@@ -19,19 +18,3 @@ class BasicTokenizer {
|
||||
bool tokenize_punctuation_;
|
||||
bool remove_control_chars_;
|
||||
};
|
||||
-
|
||||
-struct KernelBasicTokenizer : BaseKernel {
|
||||
- KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
- void Compute(OrtKernelContext* context);
|
||||
- private:
|
||||
- std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||
-};
|
||||
-
|
||||
-struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
||||
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
- const char* GetName() const;
|
||||
- size_t GetInputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
- size_t GetOutputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
-};
|
||||
diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
|
||||
index b860ba6..9f43c5e 100644
|
||||
--- a/operators/tokenizer/bert_tokenizer.cc
|
||||
+++ b/operators/tokenizer/bert_tokenizer.cc
|
||||
@@ -33,7 +33,8 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
|
||||
|
||||
auto it = vocab_.find(utf8_token);
|
||||
if (it == vocab_.end()) {
|
||||
- ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
|
||||
+ std::cout << "[BertTokenizerVocab]: can not find tokens: " + std::string(token);
|
||||
+ return -1;
|
||||
}
|
||||
|
||||
return it->second;
|
||||
@@ -276,138 +277,3 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
|
||||
}
|
||||
}
|
||||
|
||||
-KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
- std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||
- bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
- bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
|
||||
- std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
||||
- std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
|
||||
- std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
|
||||
- std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
|
||||
- std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
|
||||
- bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
- bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
- std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
|
||||
- std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
|
||||
- int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
|
||||
-
|
||||
- tokenizer_ = std::make_unique<BertTokenizer>(
|
||||
- vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
|
||||
- ustring(sep_token), ustring(pad_token), ustring(cls_token),
|
||||
- ustring(mask_token), tokenize_chinese_chars, strip_accents,
|
||||
- ustring(suffix_indicator), max_len, truncation_strategy_name);
|
||||
-}
|
||||
-
|
||||
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
- // Setup inputs
|
||||
- const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
- std::vector<std::string> input_data;
|
||||
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
-
|
||||
- if (input_data.size() != 1 && input_data.size() != 2) {
|
||||
- ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
|
||||
- }
|
||||
- std::vector<int64_t> input_ids;
|
||||
- std::vector<int64_t> token_type_ids;
|
||||
-
|
||||
- if (input_data.size() == 1) {
|
||||
- std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
- std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
|
||||
- tokenizer_->Truncate(encoded);
|
||||
- input_ids = tokenizer_->AddSpecialToken(encoded);
|
||||
- token_type_ids = tokenizer_->GenerateTypeId(encoded);
|
||||
- } else {
|
||||
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
|
||||
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
|
||||
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
|
||||
- input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
|
||||
- token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
|
||||
- }
|
||||
-
|
||||
- std::vector<int64_t> attention_mask(input_ids.size(), 1);
|
||||
-
|
||||
- std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
|
||||
-
|
||||
- SetOutput(context, 0, output_dim, input_ids);
|
||||
- SetOutput(context, 1, output_dim, token_type_ids);
|
||||
- SetOutput(context, 2, output_dim, attention_mask);
|
||||
-}
|
||||
-
|
||||
-void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
- return CreateKernelImpl(api, info);
|
||||
-}
|
||||
-
|
||||
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
|
||||
-
|
||||
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
|
||||
- return 1;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-}
|
||||
-
|
||||
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
|
||||
- return 3;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
-}
|
||||
-
|
||||
-KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
||||
-
|
||||
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
- // Setup inputs
|
||||
- const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
|
||||
- std::vector<std::string> input_data;
|
||||
- GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
-
|
||||
- if (input_data.size() != 2) {
|
||||
- ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
|
||||
- }
|
||||
-
|
||||
- std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||
- std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
|
||||
- std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
|
||||
- std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
|
||||
- std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
|
||||
- std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
|
||||
- std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
|
||||
-
|
||||
- const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
|
||||
- const std::vector<int64_t> inner_dims{1LL};
|
||||
- for (int32_t i = 0; i < 3; ++i) {
|
||||
- OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
|
||||
- OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
|
||||
- ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
|
||||
- ort_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
- }
|
||||
-
|
||||
- SetOutput(context, 0, outer_dims, input_ids);
|
||||
- SetOutput(context, 1, outer_dims, attention_mask);
|
||||
- SetOutput(context, 2, outer_dims, token_type_ids);
|
||||
-}
|
||||
-
|
||||
-void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
- return CreateKernelImpl(api, info);
|
||||
-}
|
||||
-
|
||||
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
|
||||
-
|
||||
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
|
||||
- return 1;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
-}
|
||||
-
|
||||
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
|
||||
- return 3;
|
||||
-}
|
||||
-
|
||||
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
|
||||
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
-}
|
||||
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
index 6dfcd84..10565e4 100644
--- a/operators/tokenizer/bert_tokenizer.hpp
+++ b/operators/tokenizer/bert_tokenizer.hpp
@@ -3,12 +3,11 @@
 
 #pragma once
 
+#include <memory>
 #include <unordered_map>
 #include <vector>
-#include "ocos.h"
 #include "ustring.h"
-#include "string_utils.h"
-#include "string_tensor.h"
+#include "string_utils_onnx.h"
 #include "basic_tokenizer.hpp"
 
 class BertTokenizerVocab final {
@@ -89,33 +88,4 @@ class BertTokenizer final {
|
||||
std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
|
||||
};
|
||||
|
||||
-struct KernelBertTokenizer : BaseKernel {
|
||||
- KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
- void Compute(OrtKernelContext* context);
|
||||
|
||||
- protected:
|
||||
- std::unique_ptr<BertTokenizer> tokenizer_;
|
||||
-};
|
||||
-
|
||||
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
||||
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
- const char* GetName() const;
|
||||
- size_t GetInputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
- size_t GetOutputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
-};
|
||||
-
|
||||
-struct KernelHfBertTokenizer : KernelBertTokenizer {
|
||||
- KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
- void Compute(OrtKernelContext* context);
|
||||
-};
|
||||
-
|
||||
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
||||
- void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
- const char* GetName() const;
|
||||
- size_t GetInputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
- size_t GetOutputTypeCount() const;
|
||||
- ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
-};
|
19
cmake/onnxruntime.cmake
Normal file
@@ -0,0 +1,19 @@
project(onnxruntime)
set(ONNX_NAME onnxruntime)
include(ExternalProject)

if(NOT EXISTS ${DEP_ROOT_DIR}/${ONNX_NAME})
    file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_NAME})
    file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_NAME}-build)
endif()

ExternalProject_Add(
        onnxruntime
        GIT_REPOSITORY https://github.com/microsoft/onnxruntime
        GIT_TAG origin/rel-1.14.0
        SOURCE_DIR ${DEP_ROOT_DIR}/${ONNX_NAME}
        PATCH_COMMAND cd ${DEP_ROOT_DIR}/${ONNX_NAME} && git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx.patch -R --check
        BINARY_DIR ${DEP_ROOT_DIR}/${ONNX_NAME}-build
CONFIGURE_COMMAND ${CMAKE_COMMAND} ${DEP_ROOT_DIR}/${ONNX_NAME}/cmake -B${DEP_ROOT_DIR}/${ONNX_NAME}-build -Donnxruntime_RUN_ONNX_TESTS=OFF -Donnxruntime_GENERATE_TEST_REPORTS=ON -Donnxruntime_USE_MIMALLOC=OFF -Donnxruntime_ENABLE_PYTHON=OFF -Donnxruntime_BUILD_CSHARP=OFF -Donnxruntime_BUILD_JAVA=OFF -Donnxruntime_BUILD_NODEJS=OFF -Donnxruntime_BUILD_OBJC=OFF -Donnxruntime_BUILD_SHARED_LIB=OFF -Donnxruntime_BUILD_APPLE_FRAMEWORK=OFF -Donnxruntime_USE_DNNL=OFF -Donnxruntime_USE_NNAPI_BUILTIN=OFF -Donnxruntime_USE_RKNPU=OFF -Donnxruntime_USE_LLVM=OFF -Donnxruntime_ENABLE_MICROSOFT_INTERNAL=OFF -Donnxruntime_USE_VITISAI=OFF -Donnxruntime_USE_TENSORRT=OFF -Donnxruntime_SKIP_AND_PERFORM_FILTERED_TENSORRT_TESTS=ON -Donnxruntime_USE_TENSORRT_BUILTIN_PARSER=OFF -Donnxruntime_TENSORRT_PLACEHOLDER_BUILDER=OFF -Donnxruntime_USE_TVM=OFF -Donnxruntime_TVM_CUDA_RUNTIME=OFF -Donnxruntime_TVM_USE_HASH=OFF -Donnxruntime_USE_MIGRAPHX=OFF -Donnxruntime_CROSS_COMPILING=OFF -Donnxruntime_DISABLE_CONTRIB_OPS=OFF -Donnxruntime_DISABLE_ML_OPS=OFF -Donnxruntime_DISABLE_RTTI=OFF -Donnxruntime_DISABLE_EXCEPTIONS=OFF -Donnxruntime_MINIMAL_BUILD=OFF -Donnxruntime_EXTENDED_MINIMAL_BUILD=OFF -Donnxruntime_MINIMAL_BUILD_CUSTOM_OPS=OFF -Donnxruntime_REDUCED_OPS_BUILD=OFF -Donnxruntime_ENABLE_LANGUAGE_INTEROP_OPS=OFF -Donnxruntime_USE_DML=OFF -Donnxruntime_USE_WINML=OFF -Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=OFF -Donnxruntime_USE_TELEMETRY=OFF -Donnxruntime_ENABLE_LTO=OFF -Donnxruntime_USE_ACL=OFF -Donnxruntime_USE_ACL_1902=OFF -Donnxruntime_USE_ACL_1905=OFF -Donnxruntime_USE_ACL_1908=OFF -Donnxruntime_USE_ACL_2002=OFF -Donnxruntime_USE_ARMNN=OFF -Donnxruntime_ARMNN_RELU_USE_CPU=ON -Donnxruntime_ARMNN_BN_USE_CPU=ON -Donnxruntime_ENABLE_NVTX_PROFILE=OFF -Donnxruntime_ENABLE_TRAINING=OFF -Donnxruntime_ENABLE_TRAINING_OPS=OFF -Donnxruntime_ENABLE_TRAINING_APIS=OFF -Donnxruntime_ENABLE_CPU_FP16_OPS=OFF -Donnxruntime_USE_NCCL=OFF -Donnxruntime_BUILD_BENCHMARKS=OFF -Donnxruntime_USE_ROCM=OFF -Donnxruntime_GCOV_COVERAGE=OFF -Donnxruntime_USE_MPI=ON -Donnxruntime_ENABLE_MEMORY_PROFILE=OFF -Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=OFF -Donnxruntime_BUILD_WEBASSEMBLY=OFF -Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=ON -Donnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING=ON -Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=OFF -Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=OFF -Donnxruntime_ENABLE_EAGER_MODE=OFF -Donnxruntime_ENABLE_LAZY_TENSOR=OFF -Donnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS=OFF -Donnxruntime_ENABLE_CUDA_PROFILING=OFF -Donnxruntime_ENABLE_ROCM_PROFILING=OFF -Donnxruntime_USE_XNNPACK=OFF -Donnxruntime_USE_CANN=OFF -DCMAKE_TLS_VERIFY=ON -DFETCHCONTENT_QUIET=OFF -Donnxruntime_PYBIND_EXPORT_OPSCHEMA=OFF -Donnxruntime_ENABLE_MEMLEAK_CHECKER=OFF -DCMAKE_BUILD_TYPE=Release
BUILD_COMMAND ${CMAKE_COMMAND} --build ${DEP_ROOT_DIR}/${ONNX_NAME}-build --config Release -- -j8
)
21
cmake/onnxruntime_ext.cmake
Normal file
@@ -0,0 +1,21 @@
project(onnxruntime_ext)
set(ONNX_EXT_NAME onnxruntime_ext)
include(ExternalProject)

if(NOT EXISTS ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
    file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
else()
    file(REMOVE_RECURSE ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
    file(MAKE_DIRECTORY ${DEP_ROOT_DIR}/${ONNX_EXT_NAME})
endif()

ExternalProject_Add(
        onnxruntime_ext
        GIT_REPOSITORY https://github.com/microsoft/onnxruntime-extensions
        GIT_TAG origin/rel-0.6.0
        SOURCE_DIR ${DEP_ROOT_DIR}/${ONNX_EXT_NAME}
PATCH_COMMAND cd ${DEP_ROOT_DIR}/${ONNX_EXT_NAME} && git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx_ext.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/onnx_ext.patch -R --check && cd operators && mkdir -p src_dir && mkdir -p src_dir/tokenizer && cp base64.h base64.cc string_utils_onnx.h string_utils_onnx.cc ustring.h ustring.cc src_dir && cp tokenizer/bert_tokenizer.hpp tokenizer/bert_tokenizer.cc tokenizer/basic_tokenizer.hpp tokenizer/basic_tokenizer.cc src_dir/tokenizer
        CONFIGURE_COMMAND ""
        BUILD_COMMAND ""
        INSTALL_COMMAND ""
)
@ -265,6 +265,8 @@ private:
|
||||
const spp::sparse_hash_set<std::string>& exclude_fields,
|
||||
tsl::htrie_set<char>& include_fields_full,
|
||||
tsl::htrie_set<char>& exclude_fields_full) const;
|
||||
|
||||
|
||||
|
||||
public:
|
||||
|
||||
@ -345,6 +347,8 @@ public:
|
||||
const DIRTY_VALUES dirty_values,
|
||||
const std::string& id="");
|
||||
|
||||
Option<bool> embed_fields(nlohmann::json& document);
|
||||
|
||||
static uint32_t get_seq_id_from_key(const std::string & key);
|
||||
|
||||
Option<bool> get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const;
|
||||
@ -402,7 +406,7 @@ public:
|
||||
tsl::htrie_set<char>& include_fields_full,
|
||||
tsl::htrie_set<char>& exclude_fields_full) const;
|
||||
|
||||
Option<nlohmann::json> search(const std::string & query, const std::vector<std::string> & search_fields,
|
||||
Option<nlohmann::json> search(std::string query, const std::vector<std::string> & search_fields,
|
||||
const std::string & filter_query, const std::vector<std::string> & facet_fields,
|
||||
const std::vector<sort_by> & sort_fields, const std::vector<uint32_t>& num_typos,
|
||||
size_t per_page = 10, size_t page = 1,
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <sparsepp.h>
|
||||
#include <tsl/htrie_map.h>
|
||||
#include "json.hpp"
|
||||
#include "text_embedder_manager.h"
|
||||
|
||||
namespace field_types {
|
||||
// first field value indexed will determine the type
|
||||
@ -48,6 +49,8 @@ namespace fields {
|
||||
static const std::string num_dim = "num_dim";
|
||||
static const std::string vec_dist = "vec_dist";
|
||||
static const std::string reference = "reference";
|
||||
static const std::string create_from = "create_from";
|
||||
static const std::string model_name = "model_name";
|
||||
}
|
||||
|
||||
enum vector_distance_type_t {
|
||||
@ -73,6 +76,8 @@ struct field {
|
||||
int nested_array;
|
||||
|
||||
size_t num_dim;
|
||||
std::vector<std::string> create_from;
|
||||
std::string model_name;
|
||||
vector_distance_type_t vec_dist;
|
||||
|
||||
static constexpr int VAL_UNKNOWN = 2;
|
||||
@ -83,9 +88,9 @@ struct field {
|
||||
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "") :
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector<std::string> &create_from = {}, const std::string& model_name = "") :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), create_from(create_from), model_name(model_name) {
|
||||
|
||||
set_computed_defaults(sort, infix);
|
||||
}
|
||||
@ -313,7 +318,12 @@ struct field {
|
||||
if (!field.reference.empty()) {
|
||||
field_val[fields::reference] = field.reference;
|
||||
}
|
||||
|
||||
if(!field.create_from.empty()) {
|
||||
field_val[fields::create_from] = field.create_from;
|
||||
if(!field.model_name.empty()) {
|
||||
field_val[fields::model_name] = field.model_name;
|
||||
}
|
||||
}
|
||||
fields_json.push_back(field_val);
|
||||
|
||||
if(!field.has_valid_type()) {
|
||||
@ -409,6 +419,59 @@ struct field {
|
||||
size_t num_auto_detect_fields = 0;
|
||||
|
||||
for(nlohmann::json & field_json: fields_json) {
|
||||
|
||||
if(field_json.count(fields::create_from) != 0) {
|
||||
if(TextEmbedderManager::model_dir.empty()) {
|
||||
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
|
||||
}
|
||||
|
||||
if(!field_json[fields::create_from].is_array()) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` must be an array.");
|
||||
}
|
||||
|
||||
if(field_json[fields::create_from].empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` must have at least one element.");
|
||||
}
|
||||
|
||||
for(auto& create_from_field : field_json[fields::create_from]) {
|
||||
if(!create_from_field.is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` must be an array of strings.");
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` is only allowed on a float array field.");
|
||||
}
|
||||
|
||||
|
||||
for(auto& create_from_field : field_json[fields::create_from]) {
|
||||
bool flag = false;
|
||||
for(const auto& field : fields_json) {
|
||||
if(field[fields::name] == create_from_field) {
|
||||
if(field[fields::type] != field_types::STRING) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with array of string fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
for(const auto& field : the_fields) {
|
||||
if(field.name == create_from_field) {
|
||||
if(field.type != field_types::STRING) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with array of string fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with array of string fields.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto op = json_field_to_field(enable_nested_fields,
|
||||
field_json, the_fields, fallback_field_type, num_auto_detect_fields);
|
||||
if(!op.ok()) {
|
||||
|
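For reference, a schema fragment that satisfies the create_from validation above (a sketch; the field names are illustrative and not part of the diff):

// Valid: `embedding` is a float[] field generated from an existing string field.
auto fields_json = R"([
    {"name": "title",     "type": "string"},
    {"name": "embedding", "type": "float[]", "create_from": ["title"]}
])"_json;
// A create_from value that is not an array of strings, a non-float[] embedding field,
// or a referenced field that is not a string is rejected with a 400 error as shown above.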
@ -24,6 +24,8 @@ private:
|
||||
|
||||
static size_t curl_write_async_done(void* context, curl_socket_t item);
|
||||
|
||||
static size_t curl_write_download(void *ptr, size_t size, size_t nmemb, FILE *stream);
|
||||
|
||||
static CURL* init_curl(const std::string& url, std::string& response);
|
||||
|
||||
static CURL* init_curl_async(const std::string& url, deferred_req_res_t* req_res, curl_slist*& chunk);
|
||||
@ -43,6 +45,8 @@ public:
|
||||
|
||||
void init(const std::string & api_key);
|
||||
|
||||
static long download_file(const std::string& url, const std::string& file_path);
|
||||
|
||||
static long get_response(const std::string& url, std::string& response,
|
||||
std::map<std::string, std::string>& res_headers, long timeout_ms=4000);
|
||||
|
||||
|
28
include/text_embedder.h
Normal file
@@ -0,0 +1,28 @@
#pragma once

#include <core/session/onnxruntime_cxx_api.h>
#include <tokenizer/bert_tokenizer.hpp>
#include <vector>

struct encoded_input_t {
    std::vector<int64_t> input_ids;
    std::vector<int64_t> token_type_ids;
    std::vector<int64_t> attention_mask;
};


class TextEmbedder {
    public:
        TextEmbedder(const std::string& model_path);
        ~TextEmbedder();
        std::vector<float> Embed(const std::string& text);

        static bool is_model_valid(const std::string& model_path, unsigned int& num_dims);
    private:
        Ort::Session* session_;
        Ort::Env env_;
        encoded_input_t Encode(const std::string& text);
        BertTokenizer* tokenizer_;
        static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
        std::string output_tensor_name;
};
88
include/text_embedder_manager.h
Normal file
@@ -0,0 +1,88 @@
|
||||
#pragma once
|
||||
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "text_embedder.h"
|
||||
|
||||
// singleton class
|
||||
class TextEmbedderManager {
|
||||
public:
|
||||
static TextEmbedderManager& get_instance() {
|
||||
static TextEmbedderManager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
TextEmbedderManager(TextEmbedderManager&&) = delete;
|
||||
TextEmbedderManager& operator=(TextEmbedderManager&&) = delete;
|
||||
TextEmbedderManager(const TextEmbedderManager&) = delete;
|
||||
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
|
||||
|
||||
TextEmbedder* get_text_embedder(const std::string& model_path) {
|
||||
if (text_embedders.find(model_path) == text_embedders.end()) {
|
||||
text_embedders[model_path] = new TextEmbedder(model_path);
|
||||
}
|
||||
return text_embedders[model_path];
|
||||
}
|
||||
|
||||
void delete_text_embedder(const std::string& model_path) {
|
||||
if (text_embedders.find(model_path) != text_embedders.end()) {
|
||||
delete text_embedders[model_path];
|
||||
text_embedders.erase(model_path);
|
||||
}
|
||||
}
|
||||
|
||||
void delete_all_text_embedders() {
|
||||
for (auto& text_embedder : text_embedders) {
|
||||
delete text_embedder.second;
|
||||
}
|
||||
text_embedders.clear();
|
||||
}
|
||||
|
||||
static void set_model_dir(const std::string& dir) {
|
||||
model_dir = dir;
|
||||
}
|
||||
|
||||
static const std::string& get_model_dir() {
|
||||
return model_dir;
|
||||
}
|
||||
|
||||
~TextEmbedderManager() {
|
||||
delete_all_text_embedders();
|
||||
}
|
||||
|
||||
static constexpr char* DEFAULT_MODEL_URL = "https://huggingface.co/typesense/models/resolve/main/e5-small/model.onnx";
|
||||
static constexpr char* DEFAULT_MODEL_NAME = "ts-e5-small";
|
||||
static constexpr char* DEFAULT_VOCAB_URL = "https://huggingface.co/typesense/models/resolve/main/e5-small/vocab.txt";
|
||||
static constexpr char* DEFAULT_VOCAB_NAME = "vocab.txt";
|
||||
inline static std::string model_dir = "";
|
||||
inline static const std::string get_absolute_model_path(const std::string& model_name) {
|
||||
if(model_dir.back() != '/') {
|
||||
if(model_name.front() != '/') {
|
||||
return model_dir + "/" + model_name + ".onnx";
|
||||
} else {
|
||||
return model_dir + model_name + ".onnx";
|
||||
}
|
||||
} else {
|
||||
if(model_name.front() != '/') {
|
||||
return model_dir + model_name + ".onnx";
|
||||
} else {
|
||||
return model_dir + "/" + model_name + ".onnx";
|
||||
}
|
||||
}
|
||||
};
|
||||
inline static const std::string get_absolute_vocab_path() {
|
||||
if(model_dir.back() != '/') {
|
||||
return model_dir + "/" + TextEmbedderManager::DEFAULT_VOCAB_NAME;
|
||||
} else {
|
||||
return model_dir + TextEmbedderManager::DEFAULT_VOCAB_NAME;
|
||||
}
|
||||
}
|
||||
private:
|
||||
TextEmbedderManager() = default;
|
||||
std::unordered_map<std::string, TextEmbedder*> text_embedders;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
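A minimal usage sketch of the TextEmbedderManager/TextEmbedder API added above (not part of the commit; the model directory path is illustrative):

#include "text_embedder_manager.h"

void embed_example() {
    TextEmbedderManager::set_model_dir("/tmp/typesense-models");  // hypothetical path
    TextEmbedder* embedder = TextEmbedderManager::get_instance()
                                 .get_text_embedder(TextEmbedderManager::DEFAULT_MODEL_NAME);
    std::vector<float> embedding = embedder->Embed("hello world");  // 384 dims for the default model
}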
@ -11,6 +11,7 @@ class Config {
|
||||
private:
|
||||
std::string data_dir;
|
||||
std::string log_dir;
|
||||
std::string model_dir;
|
||||
|
||||
std::string api_key;
|
||||
|
||||
@ -108,6 +109,10 @@ public:
|
||||
this->log_dir = log_dir;
|
||||
}
|
||||
|
||||
void set_model_dir(const std::string & model_dir) {
|
||||
this->model_dir = model_dir;
|
||||
}
|
||||
|
||||
void set_api_key(const std::string & api_key) {
|
||||
this->api_key = api_key;
|
||||
}
|
||||
@ -171,6 +176,10 @@ public:
|
||||
return this->log_dir;
|
||||
}
|
||||
|
||||
std::string get_model_dir() const {
|
||||
return this->model_dir;
|
||||
}
|
||||
|
||||
std::string get_api_key() const {
|
||||
return this->api_key;
|
||||
}
|
||||
@ -565,6 +574,10 @@ public:
|
||||
auto skip_writes_str = reader.Get("server", "skip-writes", "false");
|
||||
this->skip_writes = (skip_writes_str == "true");
|
||||
}
|
||||
|
||||
if(reader.Exists("server", "model-dir")) {
|
||||
this->model_dir = reader.Get("server", "model-dir", "");
|
||||
}
|
||||
}
|
||||
|
||||
void load_config_cmd_args(cmdline::parser & options) {
|
||||
@ -580,6 +593,11 @@ public:
|
||||
this->api_key = options.get<std::string>("api-key");
|
||||
}
|
||||
|
||||
if(options.exist("model-dir")) {
|
||||
LOG(INFO) << "model-dir found";
|
||||
this->model_dir = options.get<std::string>("model-dir");
|
||||
}
|
||||
|
||||
// @deprecated
|
||||
if(options.exist("search-only-api-key")) {
|
||||
this->search_only_api_key = options.get<std::string>("search-only-api-key");
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include "logger.h"
|
||||
#include "thread_local_vars.h"
|
||||
#include "vector_query_ops.h"
|
||||
#include "text_embedder_manager.h"
|
||||
|
||||
const std::string override_t::MATCH_EXACT = "exact";
|
||||
const std::string override_t::MATCH_CONTAINS = "contains";
|
||||
@ -71,6 +72,10 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
|
||||
const std::string& id) {
|
||||
try {
|
||||
document = nlohmann::json::parse(json_str);
|
||||
auto embed_res = embed_fields(document);
|
||||
if (!embed_res.ok()) {
|
||||
return Option<doc_seq_id_t>(400, embed_res.error());
|
||||
}
|
||||
} catch(const std::exception& e) {
|
||||
LOG(ERROR) << "JSON error: " << e.what();
|
||||
return Option<doc_seq_id_t>(400, std::string("Bad JSON: ") + e.what());
|
||||
@ -224,7 +229,15 @@ nlohmann::json Collection::get_summary_json() const {
|
||||
field_json[fields::sort] = coll_field.sort;
|
||||
field_json[fields::infix] = coll_field.infix;
|
||||
field_json[fields::locale] = coll_field.locale;
|
||||
|
||||
if(coll_field.create_from.size() > 0) {
|
||||
field_json[fields::create_from] = coll_field.create_from;
|
||||
}
|
||||
|
||||
if(coll_field.model_name.size() > 0) {
|
||||
field_json[fields::model_name] = coll_field.model_name;
|
||||
}
|
||||
|
||||
if(coll_field.num_dim > 0) {
|
||||
field_json[fields::num_dim] = coll_field.num_dim;
|
||||
}
|
||||
@ -281,6 +294,7 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
|
||||
const std::string & json_line = json_lines[i];
|
||||
Option<doc_seq_id_t> doc_seq_id_op = to_doc(json_line, document, operation, dirty_values, id);
|
||||
|
||||
|
||||
const uint32_t seq_id = doc_seq_id_op.ok() ? doc_seq_id_op.get().seq_id : 0;
|
||||
index_record record(i, seq_id, document, operation, dirty_values);
|
||||
|
||||
@ -963,8 +977,9 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
|
||||
for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) {
|
||||
bool exact_key_match = (kv.key().size() == field_name.size());
|
||||
bool exact_primitive_match = exact_key_match && !kv.value().is_object();
|
||||
bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().create_from.size() > 0;
|
||||
|
||||
if(extract_only_string_fields && !kv.value().is_string()) {
|
||||
if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) {
|
||||
if(exact_primitive_match && !is_wildcard) {
|
||||
// upstream needs to be returned an error
|
||||
return Option<bool>(400, "Field `" + field_name + "` should be a string or a string array.");
|
||||
@ -973,7 +988,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
|
||||
continue;
|
||||
}
|
||||
|
||||
if (exact_primitive_match || is_wildcard ||
|
||||
if (exact_primitive_match || is_wildcard || text_embedding ||
|
||||
// field_name prefix must be followed by a "." to indicate an object search
|
||||
(enable_nested_fields && kv.key().size() > field_name.size() && kv.key()[field_name.size()] == '.')) {
|
||||
processed_search_fields.push_back(kv.key());
|
||||
@ -992,7 +1007,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
const std::vector<std::string>& raw_search_fields,
|
||||
const std::string & filter_query, const std::vector<std::string>& facet_fields,
|
||||
const std::vector<sort_by> & sort_fields, const std::vector<uint32_t>& num_typos,
|
||||
@ -1113,10 +1128,12 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// validate search fields
|
||||
std::vector<std::string> processed_search_fields;
|
||||
std::vector<uint32_t> query_by_weights;
|
||||
|
||||
bool has_embedding_query = false;
|
||||
for(size_t i = 0; i < raw_search_fields.size(); i++) {
|
||||
const std::string& field_name = raw_search_fields[i];
|
||||
if(field_name == "id") {
|
||||
@ -1132,6 +1149,30 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
}
|
||||
|
||||
for(const auto& expanded_search_field: expanded_search_fields) {
|
||||
auto search_field = search_schema.at(expanded_search_field);
|
||||
|
||||
if(search_field.num_dim > 0) {
|
||||
if(has_embedding_query) {
|
||||
std::string error = "Only one embedding field is allowed in the query.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
if(TextEmbedderManager::model_dir.empty()) {
|
||||
std::string error = "Text embedding is not enabled. Please set `model-dir` at startup.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(search_field.model_name.size() > 0 ? search_field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
|
||||
|
||||
std::vector<float> embedding = embedder->Embed(raw_query);
|
||||
vector_query._reset();
|
||||
vector_query.values = embedding;
|
||||
vector_query.field_name = field_name;
|
||||
has_embedding_query = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
processed_search_fields.push_back(expanded_search_field);
|
||||
if(!raw_query_by_weights.empty()) {
|
||||
query_by_weights.push_back(raw_query_by_weights[i]);
|
||||
@ -1139,14 +1180,19 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
}
|
||||
}
|
||||
|
||||
std::string real_raw_query = raw_query;
|
||||
if(has_embedding_query && processed_search_fields.size() == 0) {
|
||||
raw_query = "*";
|
||||
}
|
||||
|
||||
if(!query_by_weights.empty() && processed_search_fields.size() != query_by_weights.size()) {
|
||||
std::string error = "Error, query_by_weights.size != query_by.size.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
|
||||
for(const std::string & field_name: processed_search_fields) {
|
||||
field search_field = search_schema.at(field_name);
|
||||
|
||||
if(!search_field.index) {
|
||||
std::string error = "Field `" + field_name + "` is marked as a non-indexed field in the schema.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
@ -1593,7 +1639,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
}
|
||||
|
||||
nlohmann::json result = nlohmann::json::object();
|
||||
|
||||
result["found"] = total_found;
|
||||
|
||||
if(exclude_fields.count("out_of") == 0) {
|
||||
@ -1799,7 +1844,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
wrapper_doc["geo_distance_meters"] = geo_distances;
|
||||
}
|
||||
|
||||
if(!vector_query.field_name.empty()) {
|
||||
if(!vector_query.field_name.empty() && query == "*") {
|
||||
wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]);
|
||||
}
|
||||
|
||||
@ -1983,7 +2028,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
result["request_params"] = nlohmann::json::object();
|
||||
result["request_params"]["collection_name"] = name;
|
||||
result["request_params"]["per_page"] = per_page;
|
||||
result["request_params"]["q"] = query;
|
||||
result["request_params"]["q"] = real_raw_query;
|
||||
|
||||
//long long int timeMillis = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - begin).count();
|
||||
//!LOG(INFO) << "Time taken for result calc: " << timeMillis << "us";
|
||||
@ -4586,3 +4631,33 @@ Option<bool> Collection::populate_include_exclude_fields_lk(const spp::sparse_ha
|
||||
std::shared_lock lock(mutex);
|
||||
return populate_include_exclude_fields(include_fields, exclude_fields, include_fields_full, exclude_fields_full);
|
||||
}
|
||||
|
||||
|
||||
Option<bool> Collection::embed_fields(nlohmann::json& document) {
|
||||
for(const auto& field : fields) {
|
||||
if(field.create_from.size() > 0) {
|
||||
if(TextEmbedderManager::model_dir.empty()) {
|
||||
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
|
||||
}
|
||||
std::string text_to_embed;
|
||||
for(const auto& field_name : field.create_from) {
|
||||
auto field_it = document.find(field_name);
|
||||
if(field_it != document.end()) {
|
||||
if(field_it->is_string()) {
|
||||
text_to_embed += field_it->get<std::string>() + " ";
|
||||
} else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` is not a string.");
|
||||
}
|
||||
} else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` not found in document.");
|
||||
}
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
|
||||
std::vector<float> embedding = embedder->Embed(text_to_embed);
|
||||
document[field.name] = embedding;
|
||||
}
|
||||
}
|
||||
return Option<bool>(true);
|
||||
}
|
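To illustrate the embed_fields step above: for a field defined with create_from, the document gains a model-generated vector before indexing. A sketch (values and dimension assume the default 384-dim model; not part of the diff):

// Given a schema field: {"name": "embedding", "type": "float[]", "create_from": ["name"]}
nlohmann::json doc = R"({"name": "apple iphone"})"_json;
// After embed_fields(doc) succeeds:
//   doc["name"]      == "apple iphone"            (unchanged)
//   doc["embedding"] == [0.012, -0.034, ...]      (384 floats from TextEmbedder::Embed("apple iphone "))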
@ -58,6 +58,14 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
|
||||
field_obj[fields::reference] = "";
|
||||
}
|
||||
|
||||
if(field_obj.count(fields::create_from) == 0) {
|
||||
field_obj[fields::create_from] = std::vector<std::string>();
|
||||
}
|
||||
|
||||
if(field_obj.count(fields::model_name) == 0) {
|
||||
field_obj[fields::model_name] = "";
|
||||
}
|
||||
|
||||
vector_distance_type_t vec_dist_type = vector_distance_type_t::cosine;
|
||||
|
||||
if(field_obj.count(fields::vec_dist) != 0) {
|
||||
@ -70,7 +78,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
|
||||
field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet],
|
||||
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
|
||||
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
|
||||
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference]);
|
||||
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::create_from],
|
||||
field_obj[fields::model_name]);
|
||||
|
||||
// value of `sort` depends on field type
|
||||
if(field_obj.count(fields::sort) == 0) {
|
||||
@ -200,7 +209,6 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
|
||||
for(size_t coll_index = 0; coll_index < num_collections; coll_index++) {
|
||||
const auto& collection_meta_json = collection_meta_jsons[coll_index];
|
||||
nlohmann::json collection_meta = nlohmann::json::parse(collection_meta_json, nullptr, false);
|
||||
|
||||
if(collection_meta.is_discarded()) {
|
||||
LOG(ERROR) << "Error while parsing collection meta, json: " << collection_meta_json;
|
||||
return Option<bool>(500, "Error while parsing collection meta.");
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <store.h>
|
||||
#include "field.h"
|
||||
#include "magic_enum.hpp"
|
||||
#include "text_embedder_manager.h"
|
||||
#include <stack>
|
||||
#include <collection_manager.h>
|
||||
|
||||
@ -662,6 +663,38 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json.count(fields::model_name) > 0 && field_json.count(fields::create_from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::model_name + "` can only be used with `" + fields::create_from + "`.");
|
||||
}
|
||||
|
||||
if(field_json.count(fields::create_from) != 0) {
|
||||
// If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model)
|
||||
field_json[fields::num_dim] = static_cast<unsigned int>(384);
|
||||
if(field_json.count(fields::model_name) != 0) {
|
||||
unsigned int num_dim = 0;
|
||||
if(!field_json[fields::model_name].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::model_name + "` must be a string.");
|
||||
}
|
||||
if(field_json[fields::model_name].get<std::string>().empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::model_name + "` must be a non-empty string.");
|
||||
}
|
||||
|
||||
if(TextEmbedder::is_model_valid(field_json[fields::model_name].get<std::string>(), num_dim)) {
|
||||
field_json[fields::num_dim] = num_dim;
|
||||
} else {
|
||||
return Option<bool>(400, "Property `" + fields::model_name + "` must be a valid model path.");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field_json[fields::create_from] = std::vector<std::string>();
|
||||
}
|
||||
|
||||
|
||||
if(field_json.count(fields::model_name) == 0) {
|
||||
field_json[fields::model_name] = "";
|
||||
}
|
||||
|
||||
|
||||
auto DEFAULT_VEC_DIST_METRIC = magic_enum::enum_name(vector_distance_type_t::cosine);
|
||||
|
||||
if(field_json.count(fields::num_dim) == 0) {
|
||||
@ -698,6 +731,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(field_json.count(fields::optional) == 0) {
|
||||
// dynamic type fields are always optional
|
||||
bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]);
|
||||
@ -741,7 +775,8 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
|
||||
field_json[fields::reference])
|
||||
field_json[fields::reference], field_json[fields::create_from].get<std::vector<std::string>>(),
|
||||
field_json[fields::model_name])
|
||||
);
|
||||
|
||||
if (!field_json[fields::reference].get<std::string>().empty()) {
|
||||
|
@ -352,3 +352,48 @@ size_t HttpClient::curl_write(char *contents, size_t size, size_t nmemb, std::st
|
||||
s->append(contents, size*nmemb);
|
||||
return size*nmemb;
|
||||
}
|
||||
|
||||
size_t HttpClient::curl_write_download(void *ptr, size_t size, size_t nmemb, FILE *stream) {
|
||||
size_t written = fwrite(ptr, size, nmemb, stream);
|
||||
return written;
|
||||
}
|
||||
|
||||
long HttpClient::download_file(const std::string& url, const std::string& file_path) {
|
||||
CURL *curl = curl_easy_init();
|
||||
|
||||
|
||||
if(curl == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
FILE *fp = fopen(file_path.c_str(), "wb");
|
||||
|
||||
if(fp == nullptr) {
|
||||
LOG(ERROR) << "Unable to open file for writing: " << file_path;
|
||||
return -1;
|
||||
}
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT_MS, 4000);
|
||||
curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0L);
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_download);
|
||||
// follow redirects
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
|
||||
CURLcode res_code = curl_easy_perform(curl);
|
||||
|
||||
if(res_code != CURLE_OK) {
|
||||
LOG(ERROR) << "Unable to download file: " << url << " to " << file_path << " - " << curl_easy_strerror(res_code);
|
||||
return -1;
|
||||
}
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
|
||||
|
||||
curl_easy_cleanup(curl);
|
||||
fclose(fp);
|
||||
|
||||
return http_code;
|
||||
}
|
||||
|
||||
|
@ -3010,9 +3010,51 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
sort_order, field_values, geopoint_indices,
|
||||
curated_ids_sorted, all_result_ids, all_result_ids_len, groups_processed);
|
||||
|
||||
if(!vector_query.field_name.empty()) {
|
||||
// check at least one of sort fields is text match
|
||||
bool has_text_match = false;
|
||||
for(auto& sort_field : sort_fields_std) {
|
||||
if(sort_field.name == sort_field_const::text_match) {
|
||||
has_text_match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(has_text_match) {
|
||||
VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
std::vector<std::pair<float, size_t>> dist_labels;
|
||||
auto k = std::max<size_t>(vector_query.k, per_page * page);
|
||||
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_q(vector_query.values.size());
|
||||
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
|
||||
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, filterFunctor);
|
||||
} else {
|
||||
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, filterFunctor);
|
||||
}
|
||||
|
||||
for (const auto& dist_label : dist_labels) {
|
||||
uint32 seq_id = dist_label.second;
|
||||
|
||||
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
|
||||
dist_label.first;
|
||||
auto score = (1.0 - vec_dist_score) * 100000000000.0;
|
||||
|
||||
auto found = topster->kv_map.find(seq_id);
|
||||
|
||||
if (found != topster->kv_map.end() && found->second->match_score_index >= 0 && found->second->match_score_index <= 2) {
|
||||
found->second->scores[found->second->match_score_index] += score;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/*auto timeMillis0 = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::high_resolution_clock::now() - begin0).count();
|
||||
LOG(INFO) << "Time taken for multi-field aggregation: " << timeMillis0 << "ms";*/
|
||||
|
||||
}
|
||||
|
||||
//LOG(INFO) << "topster size: " << topster->size;
|
||||
|
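A compact restatement of the vector-score folding done in the Index::search block above (same arithmetic, local names used loosely; not additional behaviour):

// dist comes from searchKnnCloserFirst(); cosine distances are made non-negative first.
float vec_dist_score = (distance_type == cosine) ? std::abs(dist) : dist;
double score = (1.0 - vec_dist_score) * 100000000000.0;
// The score is added to kv->scores[kv->match_score_index], i.e. only documents that
// already matched the text query get boosted by vector similarity.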
156
src/text_embedder.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
#include "text_embedder.h"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "logger.h"
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <filesystem>
|
||||
|
||||
TextEmbedder::TextEmbedder(const std::string& model_path) {
|
||||
|
||||
// create environment
|
||||
Ort::SessionOptions session_options;
|
||||
std::string abs_path = TextEmbedderManager::get_absolute_model_path(model_path);
|
||||
LOG(INFO) << "Loading model from: " << abs_path;
|
||||
session_ = new Ort::Session(env_, abs_path.c_str(), session_options);
|
||||
std::ifstream stream("vocab.txt");
|
||||
std::stringstream ss;
|
||||
ss << stream.rdbuf();
|
||||
auto vocab_ = ss.str();
|
||||
tokenizer_ = new BertTokenizer(vocab_, true, true, ustring("[UNK]"), ustring("[SEP]"), ustring("[PAD]"),
|
||||
ustring("[CLS]"), ustring("[MASK]"), true, true, ustring("##"),512, std::string("longest_first"));
|
||||
|
||||
auto output_tensor_count = session_->GetOutputCount();
|
||||
for (size_t i = 0; i < output_tensor_count; i++) {
|
||||
auto shape = session_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
|
||||
if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
output_tensor_name = std::string(session_->GetOutputNameAllocated(i, allocator).get());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
encoded_input_t TextEmbedder::Encode(const std::string& text) {
|
||||
|
||||
auto encoded = tokenizer_->Encode(tokenizer_->Tokenize(ustring(text)));
|
||||
auto input_ids = tokenizer_->AddSpecialToken(encoded);
|
||||
auto token_type_ids = tokenizer_->GenerateTypeId(encoded);
|
||||
auto attention_mask = std::vector<int64_t>(input_ids.size(), 1);
|
||||
return {input_ids, token_type_ids, attention_mask};
|
||||
}
|
||||
|
||||
|
||||
std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<float>>& inputs) {
|
||||
|
||||
std::vector<float> pooled_output;
|
||||
for (int i = 0; i < inputs[0].size(); i++) {
|
||||
float sum = 0;
|
||||
for (int j = 0; j < inputs.size(); j++) {
|
||||
sum += inputs[j][i];
|
||||
}
|
||||
pooled_output.push_back(sum / inputs.size());
|
||||
}
|
||||
return pooled_output;
|
||||
}
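// Worked example of mean_pooling above (illustrative values, not part of the commit):
//   inputs = {{1.0f, 2.0f}, {3.0f, 6.0f}}  ->  pooled_output = {2.0f, 4.0f}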
|
||||
|
||||
std::vector<float> TextEmbedder::Embed(const std::string& text) {
|
||||
auto encoded_input = Encode(text);
|
||||
// create input tensor object from data values
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
|
||||
std::vector<Ort::Value> input_tensors;
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<const char*> input_node_names = {"input_ids", "attention_mask", "token_type_ids"};
|
||||
input_shapes.push_back({1, static_cast<int64_t>(encoded_input.input_ids.size())});
|
||||
input_shapes.push_back({1, static_cast<int64_t>(encoded_input.attention_mask.size())});
|
||||
input_shapes.push_back({1, static_cast<int64_t>(encoded_input.token_type_ids.size())});
|
||||
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, encoded_input.input_ids.data(), encoded_input.input_ids.size(), input_shapes[0].data(), input_shapes[0].size()));
|
||||
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, encoded_input.attention_mask.data(), encoded_input.attention_mask.size(), input_shapes[1].data(), input_shapes[1].size()));
|
||||
input_tensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, encoded_input.token_type_ids.data(), encoded_input.token_type_ids.size(), input_shapes[2].data(), input_shapes[2].size()));
|
||||
|
||||
//LOG(INFO) << "Running model";
|
||||
// create output tensor object
|
||||
std::vector<const char*> output_node_names = {output_tensor_name.c_str()};
|
||||
std::vector<int64_t> output_node_dims {1, static_cast<int64_t>(encoded_input.input_ids.size()), 384}; // batch_size x seq_length x hidden_size
|
||||
auto output_tensor = session_->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_tensors.size(), output_node_names.data(), output_node_names.size());
|
||||
std::vector<std::vector<float>> output;
|
||||
float* data = output_tensor[0].GetTensorMutableData<float>();
|
||||
// print output tensor shape
|
||||
auto shape = output_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
//LOG(INFO) << "Output tensor size: " << shape[0] << " x " << shape[1] << " x " << shape[2];
|
||||
for (int i = 0; i < shape[1]; i++) {
|
||||
std::vector<float> temp;
|
||||
for (int j = 0; j < shape[2]; j++) {
|
||||
temp.push_back(data[i * shape[2] + j]);
|
||||
}
|
||||
output.push_back(temp);
|
||||
}
|
||||
//LOG(INFO) << "Mean pooling";
|
||||
auto pooled_output = mean_pooling(output);
|
||||
|
||||
|
||||
|
||||
return pooled_output;
|
||||
}
|
||||
|
||||
TextEmbedder::~TextEmbedder() {
|
||||
delete tokenizer_;
|
||||
delete session_;
|
||||
}
|
||||
|
||||
|
||||
bool TextEmbedder::is_model_valid(const std::string& model_path, unsigned int& num_dims) {
|
||||
LOG(INFO) << "Loading model: " << model_path;
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::Env env;
|
||||
std::string abs_path = TextEmbedderManager::get_absolute_model_path(model_path);
|
||||
|
||||
if(!std::filesystem::exists(abs_path)) {
|
||||
LOG(ERROR) << "Model file not found: " << abs_path;
|
||||
return false;
|
||||
}
|
||||
|
||||
Ort::Session session(env, abs_path.c_str(), session_options);
|
||||
if(session.GetInputCount() != 3) {
|
||||
LOG(ERROR) << "Invalid model: input count is not 3";
|
||||
return false;
|
||||
}
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto input_ids_name = session.GetInputNameAllocated(0, allocator);
|
||||
if (std::strcmp(input_ids_name.get(), "input_ids") != 0) {
|
||||
LOG(ERROR) << "Invalid model: input_ids tensor not found";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto attention_mask_name = session.GetInputNameAllocated(1, allocator);
|
||||
if (std::strcmp(attention_mask_name.get(), "attention_mask") != 0) {
|
||||
LOG(ERROR) << "Invalid model: attention_mask tensor not found";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto token_type_ids_name = session.GetInputNameAllocated(2, allocator);
|
||||
if (std::strcmp(token_type_ids_name.get(), "token_type_ids") != 0) {
|
||||
LOG(ERROR) << "Invalid model: token_type_ids tensor not found";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto output_tensor_count = session.GetOutputCount();
|
||||
bool found_output_tensor = false;
|
||||
for (size_t i = 0; i < output_tensor_count; i++) {
|
||||
auto shape = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
|
||||
if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) {
|
||||
num_dims = shape[2];
|
||||
found_output_tensor = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_output_tensor) {
|
||||
LOG(ERROR) << "Invalid model: Output tensor not found";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "core_api.h"
|
||||
#include "ratelimit_manager.h"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "typesense_server_utils.h"
|
||||
#include "file_utils.h"
|
||||
#include "threadpool.h"
|
||||
@ -66,6 +67,7 @@ void init_cmdline_options(cmdline::parser & options, int argc, char **argv) {
|
||||
options.set_program_name("./typesense-server");
|
||||
|
||||
options.add<std::string>("data-dir", 'd', "Directory where data will be stored.", true);
|
||||
options.add<std::string>("model-dir", '\0', "Directory where text embedding models will be stored.", false, "");
|
||||
options.add<std::string>("api-key", 'a', "API key that allows all operations.", true);
|
||||
options.add<std::string>("search-only-api-key", 's', "[DEPRECATED: use API key management end-point] API key that allows only searches.", false);
|
||||
|
||||
@ -452,6 +454,21 @@ int run_server(const Config & config, const std::string & version, void (*master
|
||||
LOG(INFO) << "Failed to initialize rate limit manager: " << rate_limit_manager_init.error();
|
||||
}
|
||||
|
||||
if(config.get_model_dir().size() > 0) {
|
||||
LOG(INFO) << "Loading text embedding models from " << config.get_model_dir();
|
||||
TextEmbedderManager::get_instance().set_model_dir(config.get_model_dir());
|
||||
|
||||
LOG(INFO) << "Downloading default model and vocab";
|
||||
long res = httpClient.download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
|
||||
if(res != 200) {
|
||||
LOG(INFO) << "Failed to download default model: " << res;
|
||||
}
|
||||
|
||||
res = httpClient.download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());
|
||||
if(res != 200) {
|
||||
LOG(INFO) << "Failed to download default vocab: " << res;
|
||||
}
|
||||
}
|
||||
// first we start the peering service
|
||||
|
||||
ReplicationState replication_state(server, batch_indexer, &store,
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include <algorithm>
|
||||
#include <collection_manager.h>
|
||||
#include "collection.h"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "http_client.h"
|
||||
|
||||
class CollectionAllFieldsTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -27,6 +29,7 @@ protected:
|
||||
|
||||
virtual void SetUp() {
|
||||
setupCollection();
|
||||
system("mkdir -p models");
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
@ -59,6 +62,8 @@ TEST_F(CollectionAllFieldsTest, IndexDocsWithoutSchema) {
|
||||
while (std::getline(infile, json_line)) {
|
||||
nlohmann::json document = nlohmann::json::parse(json_line);
|
||||
Option<nlohmann::json> add_op = coll1->add(document.dump());
|
||||
|
||||
LOG(INFO) << "Add op: " << add_op.error();
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
@ -1586,3 +1591,104 @@ TEST_F(CollectionAllFieldsTest, FieldNameMatchingRegexpShouldNotBeIndexedInNonAu
    ASSERT_EQ(1, results["hits"].size());
}

TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) {
    TextEmbedderManager::model_dir = "./models";
    nlohmann::json field_json;
    field_json["name"] = "embedding";
    field_json["type"] = "float[]";
    field_json["create_from"] = {"name"};

    std::vector<field> fields;
    std::string fallback_field_type;
    auto arr = nlohmann::json::array();
    arr.push_back(field_json);

    auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);

    ASSERT_FALSE(field_op.ok());
    ASSERT_EQ("Property `create_from` can only be used with array of string fields.", field_op.error());
}

TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) {
    TextEmbedderManager::model_dir = std::string();
    nlohmann::json field_json;
    field_json["name"] = "embedding";
    field_json["type"] = "float[]";
    field_json["create_from"] = {"name"};

    std::vector<field> fields;
    std::string fallback_field_type;
    auto arr = nlohmann::json::array();
    arr.push_back(field_json);

    auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);

    ASSERT_FALSE(field_op.ok());
    ASSERT_EQ("Text embedding is not enabled. Please set `model-dir` at startup.", field_op.error());
}

TEST_F(CollectionAllFieldsTest, CreateFromNotArray) {
    TextEmbedderManager::model_dir = "./models";
    nlohmann::json field_json;
    field_json["name"] = "embedding";
    field_json["type"] = "float[]";
    field_json["create_from"] = "name";

    std::vector<field> fields;
    std::string fallback_field_type;
    auto arr = nlohmann::json::array();
    arr.push_back(field_json);

    auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);

    ASSERT_FALSE(field_op.ok());
    ASSERT_EQ("Property `create_from` must be an array.", field_op.error());
}

TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) {
    TextEmbedderManager::model_dir = "./models";
    nlohmann::json field_json;
    field_json["name"] = "embedding";
    field_json["type"] = "float[]";
    field_json["model_name"] = "model";

    std::vector<field> fields;
    std::string fallback_field_type;
    auto arr = nlohmann::json::array();
    arr.push_back(field_json);

    auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
    ASSERT_FALSE(field_op.ok());
    ASSERT_EQ("Property `model_name` can only be used with `create_from`.", field_op.error());
}

TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) {
    TextEmbedderManager::model_dir = "./models/";
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::get_absolute_vocab_path());

    field embedding = field("embedding", field_types::FLOAT_ARRAY, false);
    embedding.create_from.push_back("name");
    std::vector<field> fields = {field("name", field_types::STRING, false),
                                 embedding};
    auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO);

    ASSERT_TRUE(obj_coll_op.ok());
    Collection* obj_coll = obj_coll_op.get();

    nlohmann::json doc1;
    doc1["name"] = "One Two Three";

    auto add_res = obj_coll->add(doc1.dump());

    ASSERT_TRUE(add_res.ok());
    ASSERT_TRUE(add_res.get()["name"].is_string());
    ASSERT_TRUE(add_res.get()["embedding"].is_array());
    ASSERT_EQ(384, add_res.get()["embedding"].size());

    // delete models folder
    system("rm -rf ./models");
}

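The tests above exercise the new `create_from` schema property: a `float[]` field that names one or more string source fields is filled with a text embedding automatically when a document is added, provided `model-dir` is configured and the default model files are present. A minimal sketch of that flow follows, reusing only calls that appear in these tests; the collection name "articles", the field names "title" and "title_embedding", and the sample document are illustrative, and the 384-dimension size matches the default model the tests download.

// Sketch (not part of the committed diff), written as it would appear inside a test body.
// Assumes the model files already exist under TextEmbedderManager::model_dir, as set up in
// CreateFromBasicValid above.
nlohmann::json schema = R"({
    "name": "articles",
    "fields": [
        {"name": "title", "type": "string"},
        {"name": "title_embedding", "type": "float[]", "create_from": ["title"]}
    ]
})"_json;

auto coll_op = collectionManager.create_collection(schema);
ASSERT_TRUE(coll_op.ok());

nlohmann::json doc;
doc["title"] = "green apple";
auto added = coll_op.get()->add(doc.dump());

// `title_embedding` is populated by the embedder; 384 floats with the default model.
ASSERT_TRUE(added.ok());
ASSERT_EQ(384, added.get()["title_embedding"].size());
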
@ -5,6 +5,8 @@
#include <algorithm>
#include <collection_manager.h>
#include "collection.h"
#include "text_embedder_manager.h"
#include "http_client.h"

class CollectionTest : public ::testing::Test {
protected:
@ -62,6 +64,7 @@ protected:
    virtual void SetUp() {
        setupCollection();
        system("mkdir -p models");
    }

    virtual void TearDown() {
@ -4605,4 +4608,153 @@ TEST_F(CollectionTest, WildcardHighlightFullFields) {
    ASSERT_EQ(0, result["hits"][0]["highlight"]["user"]["bio"].count("value"));
    ASSERT_EQ(0, result["hits"][0]["highlight"]["user_name"].count("value"));
}

TEST_F(CollectionTest, SemanticSearchTest) {
    nlohmann::json schema = R"({
        "name": "objects",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "embedding", "type":"float[]", "create_from": ["name"]}
        ]
    })"_json;

    TextEmbedderManager::model_dir = "./models/";
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::model_dir + TextEmbedderManager::get_absolute_vocab_path());

    auto op = collectionManager.create_collection(schema);
    ASSERT_TRUE(op.ok());
    Collection* coll = op.get();
    nlohmann::json object;
    object["name"] = "apple";
    auto add_op = coll->add(object.dump());
    ASSERT_TRUE(add_op.ok());

    ASSERT_EQ("apple", add_op.get()["name"]);
    ASSERT_EQ(384, add_op.get()["embedding"].size());

    spp::sparse_hash_set<std::string> dummy_include_exclude;

    auto search_res_op = coll->search("apple", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");

    ASSERT_TRUE(search_res_op.ok());
    auto search_res = search_res_op.get();
    ASSERT_EQ(1, search_res["found"].get<size_t>());
    ASSERT_EQ(1, search_res["hits"].size());
    ASSERT_EQ("apple", search_res["hits"][0]["document"]["name"].get<std::string>());
    ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size());

    // delete models folder
    system("rm -rf ./models");
}

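In SemanticSearchTest the field list passed to `coll->search(...)` contains only the embedding field, so the query string is itself embedded and matched against stored vectors, while each hit still carries the full source document. A small sketch of reading the result JSON, using only the keys the assertions above already rely on (`found`, `hits`, `document`); it assumes `search_res_op` is the successful result from the query in the test above.

// Sketch (not part of the committed diff): walk the hits of the semantic query above.
auto res = search_res_op.get();
LOG(INFO) << "found: " << res["found"].get<size_t>();
for(const auto& hit : res["hits"]) {
    const auto& doc = hit["document"];
    LOG(INFO) << doc["name"].get<std::string>()
              << " (embedding dims: " << doc["embedding"].size() << ")";
}
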
TEST_F(CollectionTest, InvalidSemanticSearch) {
    nlohmann::json schema = R"({
        "name": "objects",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "embedding", "type":"float[]", "create_from": ["name"]}
        ]
    })"_json;

    TextEmbedderManager::model_dir = "./models/";
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::model_dir + TextEmbedderManager::get_absolute_vocab_path());

    auto op = collectionManager.create_collection(schema);
    LOG(INFO) << "op.error(): " << op.error();
    ASSERT_TRUE(op.ok());
    Collection* coll = op.get();
    nlohmann::json object;
    object["name"] = "apple";
    auto add_op = coll->add(object.dump());
    ASSERT_TRUE(add_op.ok());
    LOG(INFO) << "add_op.get(): " << add_op.get().dump();
    ASSERT_EQ("apple", add_op.get()["name"]);
    ASSERT_EQ(384, add_op.get()["embedding"].size());

    spp::sparse_hash_set<std::string> dummy_include_exclude;

    auto search_res_op = coll->search("apple", {"embedding", "embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");

    ASSERT_FALSE(search_res_op.ok());

    // delete models folder
    system("rm -rf ./models");
}

TEST_F(CollectionTest, HybridSearch) {
    nlohmann::json schema = R"({
        "name": "objects",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "embedding", "type":"float[]", "create_from": ["name"]}
        ]
    })"_json;

    TextEmbedderManager::model_dir = "./models/";
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::model_dir + TextEmbedderManager::get_absolute_vocab_path());

    auto op = collectionManager.create_collection(schema);
    LOG(INFO) << collectionManager.get_collection("objects")->get_summary_json();
    ASSERT_TRUE(op.ok());
    Collection* coll = op.get();
    nlohmann::json object;
    object["name"] = "apple";
    auto add_op = coll->add(object.dump());
    LOG(INFO) << "hybrid search";
    ASSERT_TRUE(add_op.ok());

    ASSERT_EQ("apple", add_op.get()["name"]);
    ASSERT_EQ(384, add_op.get()["embedding"].size());

    spp::sparse_hash_set<std::string> dummy_include_exclude;
    LOG(INFO) << "hybrid search 2";
    auto search_res_op = coll->search("apple", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
    LOG(INFO) << "hybrid search 3";
    ASSERT_TRUE(search_res_op.ok());
    auto search_res = search_res_op.get();
    ASSERT_EQ(1, search_res["found"].get<size_t>());
    ASSERT_EQ(1, search_res["hits"].size());
    ASSERT_EQ("apple", search_res["hits"][0]["document"]["name"].get<std::string>());
    ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size());

    // delete models folder
    system("rm -rf ./models");
}

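Across SemanticSearchTest, InvalidSemanticSearch, and HybridSearch the `coll->search(...)` call is otherwise identical; the only thing that changes is the field list, which appears to select the search mode. A hedged summary follows; the remaining positional arguments are copied verbatim from the tests and not re-interpreted here, and passing the field list as a named vector is an assumption about the parameter type.

// Sketch (not part of the committed diff): the field list picks the mode in these tests.
std::vector<std::string> semantic_fields = {"embedding"};               // vector-only search (SemanticSearchTest)
std::vector<std::string> hybrid_fields   = {"name", "embedding"};       // keyword + vector (HybridSearch)
std::vector<std::string> invalid_fields  = {"embedding", "embedding"};  // duplicate field, rejected (InvalidSemanticSearch)

auto res_op = coll->search("apple", hybrid_fields, "", {}, {}, {0}, 10, 1, FREQUENCY, {true},
                           Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude,
                           10, "", 30, 4, "");
ASSERT_TRUE(res_op.ok());
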
TEST_F(CollectionTest, EmbedFielsTest) {
    nlohmann::json schema = R"({
        "name": "objects",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "embedding", "type":"float[]", "create_from": ["name"]}
        ]
    })"_json;

    TextEmbedderManager::model_dir = "./models/";
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_MODEL_URL, TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::DEFAULT_MODEL_NAME));
    HttpClient::get_instance().download_file(TextEmbedderManager::DEFAULT_VOCAB_URL, TextEmbedderManager::model_dir + TextEmbedderManager::get_absolute_vocab_path());

    auto op = collectionManager.create_collection(schema);
    ASSERT_TRUE(op.ok());
    Collection* coll = op.get();

    nlohmann::json object = R"({
        "name": "apple"
    })"_json;

    auto embed_op = coll->embed_fields(object);

    ASSERT_TRUE(embed_op.ok());

    ASSERT_EQ("apple", object["name"]);
    ASSERT_EQ(384, object["embedding"].get<std::vector<float>>().size());

    // delete models folder
    system("rm -rf ./models");
}
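The last test shows that `Collection::embed_fields` can be called directly: it mutates the passed JSON document in place, filling the embedding array for the `create_from` field without indexing the document. A minimal sketch, assuming the same collection and model setup as the test above; the document value "banana" is illustrative.

// Sketch (not part of the committed diff): embed a document in place and inspect the result.
nlohmann::json doc = R"({"name": "banana"})"_json;

auto embed_op = coll->embed_fields(doc);   // fills doc["embedding"] on success
if(embed_op.ok()) {
    LOG(INFO) << "embedding dims: " << doc["embedding"].size();
} else {
    LOG(INFO) << "embed_fields failed: " << embed_op.error();
}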