mirror of
https://github.com/typesense/typesense.git
synced 2025-04-20 12:18:31 +08:00
allow updating remote model api_key (#1944)
Some checks failed
tests / test (push) Has been cancelled
Some checks failed
tests / test (push) Has been cancelled
* allow updating remote model api_key * add nullptr checks and unique_lock * refactor field updation flow * Revert "refactor field updation flow" This reverts commit a6a9847be80ce3d8cee7ea944f5e9453091a949f. * refactor api_key updation flow * add locks in remote embedder, update api_key after validation * fix parsing of embed field * fix wrong params and parse * update minor changes, update entry in text_embedders * add comment * update embeded_fields with new api_key * add lock for updating embedded fields --------- Co-authored-by: Kishore Nallan <kishorenc@gmail.com>
This commit is contained in:
parent
87a48a5ef3
commit
b256736986
@ -266,6 +266,7 @@ private:
|
||||
std::vector<field>& addition_fields,
|
||||
std::vector<field>& reindex_fields,
|
||||
std::vector<field>& del_fields,
|
||||
std::vector<field>& update_fields,
|
||||
std::string& fallback_field_type);
|
||||
|
||||
void process_filter_overrides(std::vector<const override_t*>& filter_overrides,
|
||||
@ -425,6 +426,8 @@ public:
|
||||
|
||||
void update_metadata(const nlohmann::json& meta);
|
||||
|
||||
Option<bool> update_apikey(const nlohmann::json& model_config, const std::string& field_name);
|
||||
|
||||
Option<doc_seq_id_t> to_doc(const std::string& json_str, nlohmann::json& document,
|
||||
const index_operation_t& operation,
|
||||
const DIRTY_VALUES dirty_values,
|
||||
|
@ -232,5 +232,5 @@ public:
|
||||
|
||||
bool is_valid_api_key_collection(const std::vector<std::string>& api_key_collections, Collection* coll) const;
|
||||
|
||||
bool update_collection_metadata(const std::string& collection, const nlohmann::json& metadata);
|
||||
Option<bool> update_collection_metadata(const std::string& collection, const nlohmann::json& metadata);
|
||||
};
|
||||
|
@ -84,6 +84,8 @@ public:
|
||||
Option<bool> validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims);
|
||||
Option<bool> validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims);
|
||||
|
||||
Option<bool> update_remote_model_apikey(const nlohmann::json& model_config, const std::string& new_apikey);
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<TextEmbedder>> _get_text_embedders() {
|
||||
return text_embedders;
|
||||
}
|
||||
|
@ -37,6 +37,11 @@ class TextEmbedder {
|
||||
const TokenizerType get_tokenizer_type() {
|
||||
return tokenizer_->get_tokenizer_type();
|
||||
}
|
||||
|
||||
bool update_remote_embedder_apikey(const std::string& api_key) {
|
||||
return remote_embedder_->update_api_key(api_key);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Ort::Session> session_;
|
||||
std::shared_ptr<Ort::Env> env_;
|
||||
|
@ -28,6 +28,7 @@ class RemoteEmbedder {
|
||||
protected:
|
||||
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
|
||||
static inline ReplicationState* raft_server = nullptr;
|
||||
std::shared_mutex mutex;
|
||||
public:
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
|
||||
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
|
||||
@ -39,6 +40,7 @@ class RemoteEmbedder {
|
||||
raft_server = rs;
|
||||
}
|
||||
virtual ~RemoteEmbedder() = default;
|
||||
virtual bool update_api_key(const std::string& api_key) = 0;
|
||||
|
||||
};
|
||||
|
||||
@ -51,6 +53,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
|
||||
bool has_custom_dims;
|
||||
size_t num_dims;
|
||||
std::string openai_url = "https://api.openai.com";
|
||||
|
||||
static std::string get_openai_create_embedding_url(const std::string& openai_url) {
|
||||
return openai_url.back() == '/' ? openai_url + OPENAI_CREATE_EMBEDDING : openai_url + "/" + OPENAI_CREATE_EMBEDDING;
|
||||
}
|
||||
@ -62,6 +65,12 @@ class OpenAIEmbedder : public RemoteEmbedder {
|
||||
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2) override;
|
||||
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
|
||||
static std::string get_model_key(const nlohmann::json& model_config);
|
||||
|
||||
bool update_api_key(const std::string& apikey) override {
|
||||
std::lock_guard<std::shared_mutex> lock(mutex);
|
||||
api_key = apikey;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -72,7 +81,6 @@ class GoogleEmbedder : public RemoteEmbedder {
|
||||
inline static constexpr short GOOGLE_EMBEDDING_DIM = 768;
|
||||
inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=";
|
||||
std::string google_api_key;
|
||||
|
||||
public:
|
||||
GoogleEmbedder(const std::string& google_api_key);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
|
||||
@ -81,6 +89,11 @@ class GoogleEmbedder : public RemoteEmbedder {
|
||||
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2) override;
|
||||
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
|
||||
static std::string get_model_key(const nlohmann::json& model_config);
|
||||
bool update_api_key(const std::string& apikey) override {
|
||||
std::lock_guard<std::shared_mutex> lock(mutex);
|
||||
google_api_key = apikey;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -93,6 +106,7 @@ class GCPEmbedder : public RemoteEmbedder {
|
||||
std::string client_id;
|
||||
std::string client_secret;
|
||||
std::string model_name;
|
||||
|
||||
inline static const std::string GCP_EMBEDDING_BASE_URL = "https://us-central1-aiplatform.googleapis.com/v1/projects/";
|
||||
inline static const std::string GCP_EMBEDDING_PATH = "/locations/us-central1/publishers/google/models/";
|
||||
inline static const std::string GCP_EMBEDDING_PREDICT = ":predict";
|
||||
@ -110,6 +124,9 @@ class GCPEmbedder : public RemoteEmbedder {
|
||||
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2) override;
|
||||
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
|
||||
static std::string get_model_key(const nlohmann::json& model_config);
|
||||
bool update_api_key(const std::string& api_key) override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
@ -4813,6 +4813,46 @@ void Collection::update_metadata(const nlohmann::json& meta) {
|
||||
metadata = meta;
|
||||
}
|
||||
|
||||
Option<bool> Collection::update_apikey(const nlohmann::json& model_config, const std::string& field_name) {
|
||||
std::unique_lock ulock(mutex);
|
||||
|
||||
const auto& model_name = model_config[fields::model_name];
|
||||
const auto& api_key = model_config[fields::api_key];
|
||||
|
||||
for(auto& coll_field : fields) {
|
||||
if (coll_field.name == field_name) {
|
||||
auto &coll_model_config = coll_field.embed[fields::model_config];
|
||||
if (!coll_model_config.contains(fields::model_name) || coll_model_config[fields::model_name] != model_name) {
|
||||
return Option<bool>(400, "`model_name` mismatch for api_key updation.");
|
||||
}
|
||||
|
||||
if (!coll_model_config.contains(fields::api_key)) {
|
||||
return Option<bool>(400, "Invalid model for api_key updation.");
|
||||
}
|
||||
|
||||
if (coll_model_config[fields::api_key] == api_key) {
|
||||
return Option<bool>(400, "trying to update with same api_key.");
|
||||
}
|
||||
|
||||
//update in remote embedder first the in collection
|
||||
auto update_op = EmbedderManager::get_instance().update_remote_model_apikey(coll_model_config, api_key);
|
||||
|
||||
if (!update_op.ok()) {
|
||||
return update_op;
|
||||
}
|
||||
|
||||
coll_model_config[fields::api_key] = api_key;
|
||||
embedding_fields[field_name].embed[fields::model_config][fields::api_key] = api_key;
|
||||
|
||||
auto persist_op = persist_collection_meta();
|
||||
if (!persist_op.ok()) {
|
||||
return persist_op;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Collection::get_document_from_store(const uint32_t& seq_id,
|
||||
nlohmann::json& document, bool raw_doc) const {
|
||||
return get_document_from_store(get_seq_id_key(seq_id), document, raw_doc);
|
||||
@ -5241,11 +5281,12 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
|
||||
std::vector<field> del_fields;
|
||||
std::vector<field> addition_fields;
|
||||
std::vector<field> reindex_fields;
|
||||
std::vector<field> update_fields;
|
||||
|
||||
std::string this_fallback_field_type;
|
||||
|
||||
auto validate_op = validate_alter_payload(alter_payload, addition_fields, reindex_fields,
|
||||
del_fields, this_fallback_field_type);
|
||||
del_fields, update_fields, this_fallback_field_type);
|
||||
if(!validate_op.ok()) {
|
||||
LOG(INFO) << "Alter failed validation: " << validate_op.error();
|
||||
return validate_op;
|
||||
@ -5283,6 +5324,18 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
|
||||
}
|
||||
}
|
||||
|
||||
if(!update_fields.empty()) {
|
||||
for(const auto& f : update_fields) {
|
||||
if(f.embed.count(fields::from) != 0) {
|
||||
//it's an embed field
|
||||
auto op = update_apikey(f.embed[fields::model_config], f.name);
|
||||
if(!op.ok()) {
|
||||
return op;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hide credentials in the alter payload return
|
||||
for(auto& field_json : alter_payload["fields"]) {
|
||||
if(field_json[fields::embed].count(fields::model_config) != 0) {
|
||||
@ -5422,6 +5475,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
|
||||
std::vector<field>& addition_fields,
|
||||
std::vector<field>& reindex_fields,
|
||||
std::vector<field>& del_fields,
|
||||
std::vector<field>& update_fields,
|
||||
std::string& fallback_field_type) {
|
||||
if(!schema_changes.is_object()) {
|
||||
return Option<bool>(400, "Bad JSON.");
|
||||
@ -5614,9 +5668,28 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (found_field && field_it->embed.count(fields::from) != 0) {
|
||||
//embedded field, only api key updation is supported
|
||||
if(!kv.value().contains(fields::embed) || !kv.value()[fields::embed].is_object()) {
|
||||
return Option<bool>(400,
|
||||
"Missing or bad `embed` param.");
|
||||
}
|
||||
|
||||
if (!kv.value()[fields::embed].contains(fields::model_config) || !kv.value()[fields::embed][fields::model_config].is_object()) {
|
||||
return Option<bool>(400,
|
||||
"`model_config` should be an object containing `model_name` and `api_key`.");
|
||||
}
|
||||
|
||||
const auto &model_config = kv.value()[fields::embed][fields::model_config];
|
||||
if (!model_config.contains(fields::model_name) || !model_config.contains(fields::api_key) ||
|
||||
!model_config[fields::model_name].is_string() || !model_config[fields::api_key].is_string()) {
|
||||
return Option<bool>(400,
|
||||
"`model_config` should be an object containing `model_name` and `api_key` as string values.");
|
||||
}
|
||||
|
||||
field f(field_name, field_it->type, field_it->facet);
|
||||
f.embed = kv.value()[fields::embed];
|
||||
update_fields.push_back(f);
|
||||
} else {
|
||||
// partial update is not supported for now
|
||||
return Option<bool>(400, "Field `" + field_name + "` is already part of the schema: To "
|
||||
|
@ -2268,8 +2268,16 @@ bool CollectionManager::is_valid_api_key_collection(const std::vector<std::strin
|
||||
return api_collections.size() > 0 ? false : true;
|
||||
}
|
||||
|
||||
bool CollectionManager::update_collection_metadata(const std::string& collection, const nlohmann::json& metadata) {
|
||||
Option<bool> CollectionManager::update_collection_metadata(const std::string& collection, const nlohmann::json& metadata) {
|
||||
auto collection_ptr = get_collection(collection);
|
||||
if (collection_ptr == nullptr) {
|
||||
return Option<bool>(400, "failed to get collection.");
|
||||
}
|
||||
|
||||
collection_ptr->update_metadata(metadata);
|
||||
|
||||
std::string collection_meta_str;
|
||||
|
||||
auto collection_metakey = Collection::get_meta_key(collection);
|
||||
store->get(collection_metakey, collection_meta_str);
|
||||
|
||||
@ -2277,5 +2285,9 @@ bool CollectionManager::update_collection_metadata(const std::string& collection
|
||||
|
||||
collection_meta_json[Collection::COLLECTION_METADATA] = metadata;
|
||||
|
||||
return store->insert(collection_metakey, collection_meta_json.dump());
|
||||
if(store->insert(collection_metakey, collection_meta_json.dump())) {
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
return Option<bool>(400, "failed to insert into store.");
|
||||
}
|
@ -354,10 +354,13 @@ bool patch_update_collection(const std::shared_ptr<http_req>& req, const std::sh
|
||||
return false;
|
||||
}
|
||||
|
||||
collection->update_metadata(req_json["metadata"]);
|
||||
|
||||
//update in db
|
||||
collectionManager.update_collection_metadata(req->params["collection"], req_json["metadata"]);
|
||||
//update in collection metadata and store in db
|
||||
auto op = collectionManager.update_collection_metadata(req->params["collection"], req_json["metadata"]);
|
||||
if(!op.ok()) {
|
||||
res->set(op.code(), op.error());
|
||||
alter_in_progress = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(req_json.contains("fields")) {
|
||||
@ -369,6 +372,8 @@ bool patch_update_collection(const std::shared_ptr<http_req>& req, const std::sh
|
||||
alter_in_progress = false;
|
||||
return false;
|
||||
}
|
||||
// without this line, response will return full api key without being masked
|
||||
req_json["fields"] = alter_payload["fields"];
|
||||
}
|
||||
|
||||
alter_in_progress = false;
|
||||
|
@ -62,6 +62,32 @@ Option<bool> EmbedderManager::validate_and_init_remote_model(const nlohmann::jso
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> EmbedderManager::update_remote_model_apikey(const nlohmann::json &model_config, const std::string& new_apikey) {
|
||||
std::unique_lock<std::mutex> lock(text_embedders_mutex);
|
||||
const auto& model_key = RemoteEmbedder::get_model_key(model_config);
|
||||
|
||||
if(text_embedders.find(model_key) == text_embedders.end()) {
|
||||
return Option<bool>(404, "Text embedder was not found.");
|
||||
}
|
||||
|
||||
if(!text_embedders[model_key]->is_remote()) {
|
||||
return Option<bool>(400, "Text embedder is not valid.");
|
||||
}
|
||||
|
||||
if(!text_embedders[model_key]->update_remote_embedder_apikey(new_apikey)) {
|
||||
return Option<bool>(400, "Failed to update remote model api_key.");
|
||||
}
|
||||
|
||||
//update text embedder with new api_key and remove old entry
|
||||
auto updated_model_config = model_config;
|
||||
updated_model_config["api_key"] = new_apikey;
|
||||
const auto& updated_model_key = RemoteEmbedder::get_model_key(updated_model_config);
|
||||
text_embedders[updated_model_key] = text_embedders[model_key];
|
||||
text_embedders.erase(model_key);
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> EmbedderManager::validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims) {
|
||||
const std::string& model_name = model_config["model_name"].get<std::string>();
|
||||
Option<bool> public_model_op = EmbedderManager::get_instance().init_public_model(model_name);
|
||||
|
@ -139,6 +139,8 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
}
|
||||
|
||||
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
|
||||
std::shared_lock<std::shared_mutex> lock(mutex);
|
||||
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
@ -330,6 +332,8 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
}
|
||||
|
||||
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
|
||||
std::shared_lock<std::shared_mutex> lock(mutex);
|
||||
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Content-Type"] = "application/json";
|
||||
@ -481,6 +485,8 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, siz
|
||||
}
|
||||
|
||||
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
|
||||
std::shared_lock<std::shared_mutex> lock(mutex);
|
||||
|
||||
nlohmann::json req_body;
|
||||
req_body["instances"] = nlohmann::json::array();
|
||||
nlohmann::json instance;
|
||||
|
Loading…
x
Reference in New Issue
Block a user