mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 05:32:30 +08:00
Review changes IV
This commit is contained in:
parent
b713a12be1
commit
09926b0d69
@ -8,6 +8,7 @@
|
||||
#include <fstream>
|
||||
#include "logger.h"
|
||||
#include "http_client.h"
|
||||
#include "option.h"
|
||||
#include "text_embedder.h"
|
||||
|
||||
struct text_embedding_model {
|
||||
@ -34,7 +35,7 @@ public:
|
||||
TextEmbedderManager(const TextEmbedderManager&) = delete;
|
||||
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
|
||||
|
||||
TextEmbedder* get_text_embedder(const nlohmann::json& model_config);
|
||||
Option<TextEmbedder*> get_text_embedder(const nlohmann::json& model_config);
|
||||
void delete_text_embedder(const std::string& model_path);
|
||||
void delete_all_text_embedders();
|
||||
|
||||
@ -59,7 +60,7 @@ public:
|
||||
static const std::string get_model_name_without_namespace(const std::string& model_name);
|
||||
static const std::string get_model_subdir(const std::string& model_name);
|
||||
static const bool check_md5(const std::string& file_path, const std::string& target_md5);
|
||||
void download_public_model(const std::string& model_name);
|
||||
Option<bool> download_public_model(const std::string& model_name);
|
||||
|
||||
const bool is_public_model(const std::string& model_name);
|
||||
|
||||
|
@ -242,13 +242,13 @@ nlohmann::json Collection::get_summary_json() const {
|
||||
if(coll_field.embed.count(fields::from) != 0) {
|
||||
field_json[fields::embed] = coll_field.embed;
|
||||
|
||||
if(field_json[fields::embed].count(fields::api_key) != 0) {
|
||||
if(field_json[fields::embed].count(fields::model_config) != 0 && field_json[fields::embed][fields::model_config].count(fields::api_key) != 0) {
|
||||
// hide api key with * except first 3 chars
|
||||
std::string api_key = field_json[fields::embed][fields::api_key];
|
||||
std::string api_key = field_json[fields::embed][fields::model_config][fields::api_key];
|
||||
if(api_key.size() > 3) {
|
||||
field_json[fields::embed][fields::api_key] = api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*');
|
||||
field_json[fields::embed][fields::model_config][fields::api_key] = api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*');
|
||||
} else {
|
||||
field_json[fields::embed][fields::api_key] = api_key.replace(0, api_key.size(), api_key.size(), '*');
|
||||
field_json[fields::embed][fields::model_config][fields::api_key] = api_key.replace(0, api_key.size(), api_key.size(), '*');
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1189,7 +1189,19 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(search_field.embed[fields::model_config]);
|
||||
auto embedder_op = embedder_manager.get_text_embedder(search_field.embed[fields::model_config]);
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<nlohmann::json>(400, embedder_op.error());
|
||||
}
|
||||
auto embedder = embedder_op.get();
|
||||
|
||||
if(embedder->is_openai()) {
|
||||
// return error if prefix search is used with openai embedder
|
||||
if((prefixes.size() == 1 && prefixes[0] == true) || (prefixes.size() > 1 && prefixes[i] == true)) {
|
||||
std::string error = "Prefix search is not supported for OpenAI embedder.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
}
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
|
||||
auto embedding_op = embedder->Embed(embed_query);
|
||||
|
@ -6365,8 +6365,13 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
text_to_embed.push_back(text);
|
||||
}
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
|
||||
auto embedding_op = embedder->batch_embed(text_to_embed);
|
||||
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
|
||||
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<bool>(400, embedder_op.error());
|
||||
}
|
||||
|
||||
auto embedding_op = embedder_op.get()->batch_embed(text_to_embed);
|
||||
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<bool>(400, embedding_op.error());
|
||||
|
@ -73,7 +73,6 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
}
|
||||
return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
|
||||
} else {
|
||||
LOG(INFO) << "Embedding text: " << text;
|
||||
auto encoded_input = tokenizer_->Encode(text);
|
||||
// create input tensor object from data values
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
@ -160,7 +159,11 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n
|
||||
LOG(INFO) << "Loading model: " << model_name;
|
||||
|
||||
if(TextEmbedderManager::get_instance().is_public_model(model_name)) {
|
||||
TextEmbedderManager::get_instance().download_public_model(model_name);
|
||||
auto res = TextEmbedderManager::get_instance().download_public_model(model_name);
|
||||
if(!res.ok()) {
|
||||
LOG(ERROR) << res.error();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -6,21 +6,24 @@ TextEmbedderManager& TextEmbedderManager::get_instance() {
|
||||
return instance;
|
||||
}
|
||||
|
||||
TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
|
||||
Option<TextEmbedder*>TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
|
||||
std::unique_lock<std::mutex> lock(text_embedders_mutex);
|
||||
const std::string& model_name = model_config.at("model_name");
|
||||
if(text_embedders[model_name] == nullptr) {
|
||||
if(model_config.count("api_key") == 0) {
|
||||
if(is_public_model(model_name)) {
|
||||
// download the model if it doesn't exist
|
||||
download_public_model(model_name);
|
||||
auto res = download_public_model(model_name);
|
||||
if(!res.ok()) {
|
||||
return Option<TextEmbedder*>(res.code(), res.error());
|
||||
}
|
||||
}
|
||||
text_embedders[model_name] = std::make_shared<TextEmbedder>(get_model_name_without_namespace(model_name));
|
||||
} else {
|
||||
text_embedders[model_name] = std::make_shared<TextEmbedder>(model_name, model_config.at("api_key").get<std::string>());
|
||||
}
|
||||
}
|
||||
return text_embedders[model_name].get();
|
||||
return Option<TextEmbedder*>(text_embedders[model_name].get());
|
||||
}
|
||||
|
||||
void TextEmbedderManager::delete_text_embedder(const std::string& model_path) {
|
||||
@ -119,7 +122,7 @@ const bool TextEmbedderManager::check_md5(const std::string& file_path, const st
|
||||
}
|
||||
return res.str() == target_md5;
|
||||
}
|
||||
void TextEmbedderManager::download_public_model(const std::string& model_name) {
|
||||
Option<bool> TextEmbedderManager::download_public_model(const std::string& model_name) {
|
||||
HttpClient& httpClient = HttpClient::get_instance();
|
||||
auto model = public_models[model_name];
|
||||
auto actual_model_name = get_model_name_without_namespace(model_name);
|
||||
@ -127,6 +130,7 @@ void TextEmbedderManager::download_public_model(const std::string& model_name) {
|
||||
long res = httpClient.download_file(get_model_url(model), get_absolute_model_path(actual_model_name));
|
||||
if(res != 200) {
|
||||
LOG(INFO) << "Failed to download public model " << model_name << ": " << res;
|
||||
return Option<bool>(400, "Failed to download model file");
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,9 +138,11 @@ void TextEmbedderManager::download_public_model(const std::string& model_name) {
|
||||
long res = httpClient.download_file(get_vocab_url(model), get_absolute_vocab_path(actual_model_name, model.vocab_file_name));
|
||||
if(res != 200) {
|
||||
LOG(INFO) << "Failed to download default vocab " << model_name << ": " << res;
|
||||
return Option<bool>(400, "Failed to download vocab file");
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
const bool TextEmbedderManager::is_public_model(const std::string& model_name) {
|
||||
|
@ -4892,6 +4892,21 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) {
|
||||
ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, WrongTypeInEmbedFrom) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "category", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": [1122], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_FALSE(op.ok());
|
||||
ASSERT_EQ("Property `embed.from` must contain only field names as strings.", op.error());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, WrongTypeForEmbedding) {
|
||||
nlohmann::json schema = R"({
|
||||
@ -4991,7 +5006,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -5001,15 +5016,13 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
}
|
||||
|
||||
auto api_key = std::string(std::getenv("api_key"));
|
||||
schema["fields"][1]["model_config"]["api_key"] = api_key;
|
||||
schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key;
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
auto summary = op.get()->get_summary_json();
|
||||
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_config"]["model_name"]);
|
||||
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["embed"]["model_config"]["model_name"]);
|
||||
ASSERT_EQ(1536, summary["fields"][1]["num_dim"]);
|
||||
// make sure api_key is <hidden>
|
||||
ASSERT_EQ("<hidden>", summary["fields"][1]["model_config"]["api_key"]);
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "butter";
|
||||
@ -5019,7 +5032,68 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
ASSERT_EQ(1536, add_op.get()["embedding"].size());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, MoreThganOneEmbeddingField) {
|
||||
TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
if (std::getenv("api_key") == nullptr) {
|
||||
LOG(INFO) << "Skipping test as api_key is not set.";
|
||||
return;
|
||||
}
|
||||
|
||||
auto api_key = std::string(std::getenv("api_key"));
|
||||
schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key;
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
auto summary = op.get()->get_summary_json();
|
||||
// hide api key with * after first 3 characters
|
||||
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'));
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
if (std::getenv("api_key") == nullptr) {
|
||||
LOG(INFO) << "Skipping test as api_key is not set.";
|
||||
return;
|
||||
}
|
||||
|
||||
auto api_key = std::string(std::getenv("api_key"));
|
||||
schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key;
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "butter";
|
||||
|
||||
auto add_op = op.get()->add(doc.dump());
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
spp::sparse_hash_set<std::string> dummy_include_exclude;
|
||||
auto search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
ASSERT_EQ("Prefix search is not supported for OpenAI embedder.", search_res_op.error());
|
||||
|
||||
search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionTest, MoreThanOneEmbeddingField) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
|
Loading…
x
Reference in New Issue
Block a user