Review changes IV

This commit is contained in:
ozanarmagan 2023-05-10 17:35:31 +03:00
parent b713a12be1
commit 09926b0d69
6 changed files with 122 additions and 21 deletions

View File

@ -8,6 +8,7 @@
#include <fstream>
#include "logger.h"
#include "http_client.h"
#include "option.h"
#include "text_embedder.h"
struct text_embedding_model {
@ -34,7 +35,7 @@ public:
TextEmbedderManager(const TextEmbedderManager&) = delete;
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
TextEmbedder* get_text_embedder(const nlohmann::json& model_config);
Option<TextEmbedder*> get_text_embedder(const nlohmann::json& model_config);
void delete_text_embedder(const std::string& model_path);
void delete_all_text_embedders();
@ -59,7 +60,7 @@ public:
static const std::string get_model_name_without_namespace(const std::string& model_name);
static const std::string get_model_subdir(const std::string& model_name);
static const bool check_md5(const std::string& file_path, const std::string& target_md5);
void download_public_model(const std::string& model_name);
Option<bool> download_public_model(const std::string& model_name);
const bool is_public_model(const std::string& model_name);

View File

@ -242,13 +242,13 @@ nlohmann::json Collection::get_summary_json() const {
if(coll_field.embed.count(fields::from) != 0) {
field_json[fields::embed] = coll_field.embed;
if(field_json[fields::embed].count(fields::api_key) != 0) {
if(field_json[fields::embed].count(fields::model_config) != 0 && field_json[fields::embed][fields::model_config].count(fields::api_key) != 0) {
// hide api key with * except first 3 chars
std::string api_key = field_json[fields::embed][fields::api_key];
std::string api_key = field_json[fields::embed][fields::model_config][fields::api_key];
if(api_key.size() > 3) {
field_json[fields::embed][fields::api_key] = api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*');
field_json[fields::embed][fields::model_config][fields::api_key] = api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*');
} else {
field_json[fields::embed][fields::api_key] = api_key.replace(0, api_key.size(), api_key.size(), '*');
field_json[fields::embed][fields::model_config][fields::api_key] = api_key.replace(0, api_key.size(), api_key.size(), '*');
}
}
}
@ -1189,7 +1189,19 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(search_field.embed[fields::model_config]);
auto embedder_op = embedder_manager.get_text_embedder(search_field.embed[fields::model_config]);
if(!embedder_op.ok()) {
return Option<nlohmann::json>(400, embedder_op.error());
}
auto embedder = embedder_op.get();
if(embedder->is_openai()) {
// return error if prefix search is used with openai embedder
if((prefixes.size() == 1 && prefixes[0] == true) || (prefixes.size() > 1 && prefixes[i] == true)) {
std::string error = "Prefix search is not supported for OpenAI embedder.";
return Option<nlohmann::json>(400, error);
}
}
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
auto embedding_op = embedder->Embed(embed_query);

View File

@ -6365,8 +6365,13 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
text_to_embed.push_back(text);
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
auto embedding_op = embedder->batch_embed(text_to_embed);
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
if(!embedder_op.ok()) {
return Option<bool>(400, embedder_op.error());
}
auto embedding_op = embedder_op.get()->batch_embed(text_to_embed);
if(!embedding_op.ok()) {
return Option<bool>(400, embedding_op.error());

View File

@ -73,7 +73,6 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
}
return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
} else {
LOG(INFO) << "Embedding text: " << text;
auto encoded_input = tokenizer_->Encode(text);
// create input tensor object from data values
Ort::AllocatorWithDefaultOptions allocator;
@ -160,7 +159,11 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n
LOG(INFO) << "Loading model: " << model_name;
if(TextEmbedderManager::get_instance().is_public_model(model_name)) {
TextEmbedderManager::get_instance().download_public_model(model_name);
auto res = TextEmbedderManager::get_instance().download_public_model(model_name);
if(!res.ok()) {
LOG(ERROR) << res.error();
return false;
}
}

View File

@ -6,21 +6,24 @@ TextEmbedderManager& TextEmbedderManager::get_instance() {
return instance;
}
// Returns the TextEmbedder registered under model_config["model_name"], creating and
// caching it in `text_embedders` on first use.
//
// NOTE(review): this listing appears to be a rendered commit diff without +/- markers, so a
// changed line shows up twice: the pre-commit line first, its replacement directly below.
// The first signature returns a bare pointer; the second wraps it in Option<TextEmbedder*>
// so that a failed model download can be reported to the caller.
TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
Option<TextEmbedder*>TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
// Held for the entire lookup/creation below.
std::unique_lock<std::mutex> lock(text_embedders_mutex);
const std::string& model_name = model_config.at("model_name");
// Only build an embedder when none is stored under this name yet.
if(text_embedders[model_name] == nullptr) {
// No "api_key" in the config: the embedder is constructed from the model name alone.
if(model_config.count("api_key") == 0) {
if(is_public_model(model_name)) {
// download the model if it doesn't exist
download_public_model(model_name);
auto res = download_public_model(model_name);
// Replacement behaviour: forward the download error (code and message) instead of ignoring it.
if(!res.ok()) {
return Option<TextEmbedder*>(res.code(), res.error());
}
}
// The embedder is constructed with the name returned by get_model_name_without_namespace().
text_embedders[model_name] = std::make_shared<TextEmbedder>(get_model_name_without_namespace(model_name));
} else {
// "api_key" present: the embedder is constructed with the full model name and the key.
text_embedders[model_name] = std::make_shared<TextEmbedder>(model_name, model_config.at("api_key").get<std::string>());
}
}
// The returned pointer is non-owning; the shared_ptr stored in `text_embedders` keeps the object alive.
return text_embedders[model_name].get();
return Option<TextEmbedder*>(text_embedders[model_name].get());
}
void TextEmbedderManager::delete_text_embedder(const std::string& model_path) {
@ -119,7 +122,7 @@ const bool TextEmbedderManager::check_md5(const std::string& file_path, const st
}
return res.str() == target_md5;
}
void TextEmbedderManager::download_public_model(const std::string& model_name) {
Option<bool> TextEmbedderManager::download_public_model(const std::string& model_name) {
HttpClient& httpClient = HttpClient::get_instance();
auto model = public_models[model_name];
auto actual_model_name = get_model_name_without_namespace(model_name);
@ -127,6 +130,7 @@ void TextEmbedderManager::download_public_model(const std::string& model_name) {
long res = httpClient.download_file(get_model_url(model), get_absolute_model_path(actual_model_name));
if(res != 200) {
LOG(INFO) << "Failed to download public model " << model_name << ": " << res;
return Option<bool>(400, "Failed to download model file");
}
}
@ -134,9 +138,11 @@ void TextEmbedderManager::download_public_model(const std::string& model_name) {
long res = httpClient.download_file(get_vocab_url(model), get_absolute_vocab_path(actual_model_name, model.vocab_file_name));
if(res != 200) {
LOG(INFO) << "Failed to download default vocab " << model_name << ": " << res;
return Option<bool>(400, "Failed to download vocab file");
}
}
return Option<bool>(true);
}
const bool TextEmbedderManager::is_public_model(const std::string& model_name) {

View File

@ -4892,6 +4892,21 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) {
ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error());
}
// Collection creation must be rejected when `embed.from` contains a non-string entry
// (here the number 1122 instead of a field name).
TEST_F(CollectionTest, WrongTypeInEmbedFrom) {
    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    nlohmann::json invalid_schema = R"({
"name": "objects",
"fields": [
{"name": "category", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": [1122], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;

    auto create_op = collectionManager.create_collection(invalid_schema);

    // The schema is invalid, so creation fails with the dedicated validation message.
    ASSERT_FALSE(create_op.ok());
    ASSERT_EQ("Property `embed.from` must contain only field names as strings.", create_op.error());
}
TEST_F(CollectionTest, WrongTypeForEmbedding) {
nlohmann::json schema = R"({
@ -4991,7 +5006,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}}
]
})"_json;
@ -5001,15 +5016,13 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
}
auto api_key = std::string(std::getenv("api_key"));
schema["fields"][1]["model_config"]["api_key"] = api_key;
schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
auto summary = op.get()->get_summary_json();
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_config"]["model_name"]);
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["embed"]["model_config"]["model_name"]);
ASSERT_EQ(1536, summary["fields"][1]["num_dim"]);
// make sure api_key is <hidden>
ASSERT_EQ("<hidden>", summary["fields"][1]["model_config"]["api_key"]);
nlohmann::json doc;
doc["name"] = "butter";
@ -5019,7 +5032,68 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
ASSERT_EQ(1536, add_op.get()["embedding"].size());
}
TEST_F(CollectionTest, MoreThganOneEmbeddingField) {
// Checks that the collection summary never exposes the OpenAI api_key in clear text.
// Disabled by default: it needs a real key in the `api_key` environment variable and
// returns early (without failing) when that variable is not set.
TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) {
    nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}}
]
})"_json;

    // Read the environment once so the null check and the value cannot disagree.
    const char* env_api_key = std::getenv("api_key");
    if (env_api_key == nullptr) {
        LOG(INFO) << "Skipping test as api_key is not set.";
        return;
    }

    const std::string api_key(env_api_key);
    schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key;
    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto op = collectionManager.create_collection(schema);
    ASSERT_TRUE(op.ok());
    auto summary = op.get()->get_summary_json();

    // Expected mask follows Collection::get_summary_json(): keep the first 3 characters when
    // the key is longer than 3, otherwise mask every character. Building it from substr()
    // avoids std::string::replace throwing std::out_of_range for keys shorter than 3 chars
    // and leaves `api_key` itself untouched.
    const size_t visible_chars = api_key.size() > 3 ? 3 : 0;
    const std::string expected_key = api_key.substr(0, visible_chars) + std::string(api_key.size() - visible_chars, '*');

    ASSERT_EQ(expected_key, summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>());
}
// Checks that searching an OpenAI-embedded field is rejected when prefix search is requested
// and is accepted when prefix search is turned off.
// Disabled by default: it needs a real key in the `api_key` environment variable and
// returns early (without failing) when that variable is not set.
TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) {
    nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}}
]
})"_json;

    // Read the environment once so the null check and the value cannot disagree.
    const char* env_api_key = std::getenv("api_key");
    if (env_api_key == nullptr) {
        LOG(INFO) << "Skipping test as api_key is not set.";
        return;
    }

    auto api_key = std::string(env_api_key);
    schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key;
    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto op = collectionManager.create_collection(schema);
    ASSERT_TRUE(op.ok());

    nlohmann::json doc;
    doc["name"] = "butter";
    auto add_op = op.get()->add(doc.dump());
    ASSERT_TRUE(add_op.ok());

    spp::sparse_hash_set<std::string> dummy_include_exclude;

    // prefixes = {true}: the search must be refused with the dedicated error message.
    auto search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
    ASSERT_FALSE(search_res_op.ok());
    ASSERT_EQ("Prefix search is not supported for OpenAI embedder.", search_res_op.error());

    // prefixes = {false}: the prefix guard does not apply, so the search is expected to succeed.
    // (The previous ASSERT_FALSE here would have required non-prefix search to fail as well.)
    search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
    ASSERT_TRUE(search_res_op.ok());
}
TEST_F(CollectionTest, MoreThanOneEmbeddingField) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [