allow updating remote model api_key (#1944)
Some checks failed
tests / test (push) Has been cancelled

* allow updating remote model api_key

* add nullptr checks and unique_lock

* refactor field updation flow

* Revert "refactor field updation flow"

This reverts commit a6a9847be80ce3d8cee7ea944f5e9453091a949f.

* refactor api_key updation flow

* add locks in remote embedder, update api_key after validation

* fix parsing of embed field

* fix wrong params and parse

* update minor changes, update entry in text_embedders

* add comment

* update embeded_fields with new api_key

* add lock for updating embedded fields

---------

Co-authored-by: Kishore Nallan <kishorenc@gmail.com>
This commit is contained in:
Krunal Gandhi 2024-09-25 03:28:32 +00:00 committed by GitHub
parent 87a48a5ef3
commit b256736986
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 158 additions and 9 deletions

View File

@ -266,6 +266,7 @@ private:
std::vector<field>& addition_fields,
std::vector<field>& reindex_fields,
std::vector<field>& del_fields,
std::vector<field>& update_fields,
std::string& fallback_field_type);
void process_filter_overrides(std::vector<const override_t*>& filter_overrides,
@ -425,6 +426,8 @@ public:
void update_metadata(const nlohmann::json& meta);
Option<bool> update_apikey(const nlohmann::json& model_config, const std::string& field_name);
Option<doc_seq_id_t> to_doc(const std::string& json_str, nlohmann::json& document,
const index_operation_t& operation,
const DIRTY_VALUES dirty_values,

View File

@ -232,5 +232,5 @@ public:
bool is_valid_api_key_collection(const std::vector<std::string>& api_key_collections, Collection* coll) const;
bool update_collection_metadata(const std::string& collection, const nlohmann::json& metadata);
Option<bool> update_collection_metadata(const std::string& collection, const nlohmann::json& metadata);
};

View File

@ -84,6 +84,8 @@ public:
Option<bool> validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims);
Option<bool> validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims);
Option<bool> update_remote_model_apikey(const nlohmann::json& model_config, const std::string& new_apikey);
// Returns a by-value copy of the embedder registry: the map and its shared_ptr
// entries are copied, the TextEmbedder objects they point to are not.
// NOTE(review): no lock is taken in this accessor — confirm callers do not race
// with code that mutates text_embedders.
std::unordered_map<std::string, std::shared_ptr<TextEmbedder>> _get_text_embedders() {
    auto embedders_copy = text_embedders;
    return embedders_copy;
}

View File

@ -37,6 +37,11 @@ class TextEmbedder {
// Reports the tokenizer type by asking tokenizer_ for it.
const TokenizerType get_tokenizer_type() {
    const auto tokenizer_type = tokenizer_->get_tokenizer_type();
    return tokenizer_type;
}
// Forwards a new API key to the remote embedder backing this TextEmbedder.
//
// api_key: the key to install.
// Returns the result of RemoteEmbedder::update_api_key, or false when there is
// no remote embedder to update (remote_embedder_ is null), instead of
// dereferencing a null pointer.
bool update_remote_embedder_apikey(const std::string& api_key) {
    if(remote_embedder_ == nullptr) {
        return false;
    }

    return remote_embedder_->update_api_key(api_key);
}
private:
std::shared_ptr<Ort::Session> session_;
std::shared_ptr<Ort::Env> env_;

View File

@ -28,6 +28,7 @@ class RemoteEmbedder {
protected:
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
static inline ReplicationState* raft_server = nullptr;
std::shared_mutex mutex;
public:
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
@ -39,6 +40,7 @@ class RemoteEmbedder {
raft_server = rs;
}
virtual ~RemoteEmbedder() = default;
virtual bool update_api_key(const std::string& api_key) = 0;
};
@ -51,6 +53,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
bool has_custom_dims;
size_t num_dims;
std::string openai_url = "https://api.openai.com";
// Builds the embedding endpoint URL by appending OPENAI_CREATE_EMBEDDING to the
// given base URL, inserting a '/' separator when the base does not already end
// with one.
//
// openai_url: base URL; may be empty, in which case the result is
//             "/" + OPENAI_CREATE_EMBEDDING.
static std::string get_openai_create_embedding_url(const std::string& openai_url) {
    // std::string::back() on an empty string is undefined behaviour, so check emptiness first.
    const bool has_trailing_slash = !openai_url.empty() && openai_url.back() == '/';
    return has_trailing_slash ? openai_url + OPENAI_CREATE_EMBEDDING : openai_url + "/" + OPENAI_CREATE_EMBEDDING;
}
@ -62,6 +65,12 @@ class OpenAIEmbedder : public RemoteEmbedder {
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
static std::string get_model_key(const nlohmann::json& model_config);
// Replaces the stored OpenAI API key.
// The assignment happens while holding `mutex` exclusively; always returns true.
bool update_api_key(const std::string& apikey) override {
    {
        std::lock_guard<std::shared_mutex> write_guard(mutex);
        api_key = apikey;
    }

    return true;
}
};
@ -72,7 +81,6 @@ class GoogleEmbedder : public RemoteEmbedder {
inline static constexpr short GOOGLE_EMBEDDING_DIM = 768;
inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=";
std::string google_api_key;
public:
GoogleEmbedder(const std::string& google_api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
@ -81,6 +89,11 @@ class GoogleEmbedder : public RemoteEmbedder {
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
static std::string get_model_key(const nlohmann::json& model_config);
// Replaces the stored Google API key.
// The assignment happens while holding `mutex` exclusively; always returns true.
bool update_api_key(const std::string& apikey) override {
    {
        std::lock_guard<std::shared_mutex> write_guard(mutex);
        google_api_key = apikey;
    }

    return true;
}
};
@ -93,6 +106,7 @@ class GCPEmbedder : public RemoteEmbedder {
std::string client_id;
std::string client_secret;
std::string model_name;
inline static const std::string GCP_EMBEDDING_BASE_URL = "https://us-central1-aiplatform.googleapis.com/v1/projects/";
inline static const std::string GCP_EMBEDDING_PATH = "/locations/us-central1/publishers/google/models/";
inline static const std::string GCP_EMBEDDING_PREDICT = ":predict";
@ -110,6 +124,9 @@ class GCPEmbedder : public RemoteEmbedder {
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
static std::string get_model_key(const nlohmann::json& model_config);
// No-op implementation of the RemoteEmbedder::update_api_key hook: the argument
// is ignored, no member is modified, and the call always reports success.
// NOTE(review): presumably because this embedder authenticates with
// client_id/client_secret rather than an api_key — confirm.
bool update_api_key(const std::string& api_key) override {
    static_cast<void>(api_key);  // intentionally unused
    return true;
}
};

View File

@ -4813,6 +4813,46 @@ void Collection::update_metadata(const nlohmann::json& meta) {
metadata = meta;
}
// Rotates the `api_key` stored in the `model_config` of the embedding field
// named `field_name`.
//
// model_config: requested config; its `model_name` and `api_key` entries are read.
// field_name:   name of the collection field whose key should be replaced.
//
// For the matching field the steps are, in order:
//   1. reject (400) if the stored model_name is missing or differs from the request,
//   2. reject (400) if the stored config has no api_key,
//   3. reject (400) if the requested key equals the stored one,
//   4. update the live embedder through EmbedderManager (its error is returned as-is),
//   5. write the new key into `fields` and `embedding_fields`,
//   6. persist the collection meta (its error is returned as-is).
// Returns true on success, and also when no field is named `field_name`.
// The whole operation runs under a unique lock on the collection mutex.
Option<bool> Collection::update_apikey(const nlohmann::json& model_config, const std::string& field_name) {
    std::unique_lock ulock(mutex);

    const auto& requested_model_name = model_config[fields::model_name];
    const auto& new_api_key = model_config[fields::api_key];

    for(auto& coll_field : fields) {
        if(coll_field.name != field_name) {
            continue;
        }

        auto& stored_model_config = coll_field.embed[fields::model_config];

        if(!stored_model_config.contains(fields::model_name) ||
           stored_model_config[fields::model_name] != requested_model_name) {
            return Option<bool>(400, "`model_name` mismatch for api_key updation.");
        }

        if(!stored_model_config.contains(fields::api_key)) {
            return Option<bool>(400, "Invalid model for api_key updation.");
        }

        if(stored_model_config[fields::api_key] == new_api_key) {
            return Option<bool>(400, "trying to update with same api_key.");
        }

        // Update the remote embedder first; only then touch the collection's own copies.
        auto embedder_update_op = EmbedderManager::get_instance().update_remote_model_apikey(stored_model_config,
                                                                                             new_api_key);
        if(!embedder_update_op.ok()) {
            return embedder_update_op;
        }

        stored_model_config[fields::api_key] = new_api_key;
        embedding_fields[field_name].embed[fields::model_config][fields::api_key] = new_api_key;

        auto meta_persist_op = persist_collection_meta();
        if(!meta_persist_op.ok()) {
            return meta_persist_op;
        }
    }

    return Option<bool>(true);
}
Option<bool> Collection::get_document_from_store(const uint32_t& seq_id,
nlohmann::json& document, bool raw_doc) const {
return get_document_from_store(get_seq_id_key(seq_id), document, raw_doc);
@ -5241,11 +5281,12 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
std::vector<field> del_fields;
std::vector<field> addition_fields;
std::vector<field> reindex_fields;
std::vector<field> update_fields;
std::string this_fallback_field_type;
auto validate_op = validate_alter_payload(alter_payload, addition_fields, reindex_fields,
del_fields, this_fallback_field_type);
del_fields, update_fields, this_fallback_field_type);
if(!validate_op.ok()) {
LOG(INFO) << "Alter failed validation: " << validate_op.error();
return validate_op;
@ -5283,6 +5324,18 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
}
}
if(!update_fields.empty()) {
for(const auto& f : update_fields) {
if(f.embed.count(fields::from) != 0) {
//it's an embed field
auto op = update_apikey(f.embed[fields::model_config], f.name);
if(!op.ok()) {
return op;
}
}
}
}
// hide credentials in the alter payload return
for(auto& field_json : alter_payload["fields"]) {
if(field_json[fields::embed].count(fields::model_config) != 0) {
@ -5422,6 +5475,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
std::vector<field>& addition_fields,
std::vector<field>& reindex_fields,
std::vector<field>& del_fields,
std::vector<field>& update_fields,
std::string& fallback_field_type) {
if(!schema_changes.is_object()) {
return Option<bool>(400, "Bad JSON.");
@ -5614,9 +5668,28 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
}
}
}
} else if (found_field && field_it->embed.count(fields::from) != 0) {
//embedded field, only api key updation is supported
if(!kv.value().contains(fields::embed) || !kv.value()[fields::embed].is_object()) {
return Option<bool>(400,
"Missing or bad `embed` param.");
}
if (!kv.value()[fields::embed].contains(fields::model_config) || !kv.value()[fields::embed][fields::model_config].is_object()) {
return Option<bool>(400,
"`model_config` should be an object containing `model_name` and `api_key`.");
}
const auto &model_config = kv.value()[fields::embed][fields::model_config];
if (!model_config.contains(fields::model_name) || !model_config.contains(fields::api_key) ||
!model_config[fields::model_name].is_string() || !model_config[fields::api_key].is_string()) {
return Option<bool>(400,
"`model_config` should be an object containing `model_name` and `api_key` as string values.");
}
field f(field_name, field_it->type, field_it->facet);
f.embed = kv.value()[fields::embed];
update_fields.push_back(f);
} else {
// partial update is not supported for now
return Option<bool>(400, "Field `" + field_name + "` is already part of the schema: To "

View File

@ -2268,8 +2268,16 @@ bool CollectionManager::is_valid_api_key_collection(const std::vector<std::strin
return api_collections.size() > 0 ? false : true;
}
bool CollectionManager::update_collection_metadata(const std::string& collection, const nlohmann::json& metadata) {
Option<bool> CollectionManager::update_collection_metadata(const std::string& collection, const nlohmann::json& metadata) {
auto collection_ptr = get_collection(collection);
if (collection_ptr == nullptr) {
return Option<bool>(400, "failed to get collection.");
}
collection_ptr->update_metadata(metadata);
std::string collection_meta_str;
auto collection_metakey = Collection::get_meta_key(collection);
store->get(collection_metakey, collection_meta_str);
@ -2277,5 +2285,9 @@ bool CollectionManager::update_collection_metadata(const std::string& collection
collection_meta_json[Collection::COLLECTION_METADATA] = metadata;
return store->insert(collection_metakey, collection_meta_json.dump());
if(store->insert(collection_metakey, collection_meta_json.dump())) {
return Option<bool>(true);
}
return Option<bool>(400, "failed to insert into store.");
}

View File

@ -354,10 +354,13 @@ bool patch_update_collection(const std::shared_ptr<http_req>& req, const std::sh
return false;
}
collection->update_metadata(req_json["metadata"]);
//update in db
collectionManager.update_collection_metadata(req->params["collection"], req_json["metadata"]);
//update in collection metadata and store in db
auto op = collectionManager.update_collection_metadata(req->params["collection"], req_json["metadata"]);
if(!op.ok()) {
res->set(op.code(), op.error());
alter_in_progress = false;
return false;
}
}
if(req_json.contains("fields")) {
@ -369,6 +372,8 @@ bool patch_update_collection(const std::shared_ptr<http_req>& req, const std::sh
alter_in_progress = false;
return false;
}
// without this line, response will return full api key without being masked
req_json["fields"] = alter_payload["fields"];
}
alter_in_progress = false;

View File

@ -62,6 +62,32 @@ Option<bool> EmbedderManager::validate_and_init_remote_model(const nlohmann::jso
return Option<bool>(true);
}
// Installs a new API key on an already-registered remote embedder and re-registers
// the embedder under the model key derived from the updated config.
//
// model_config: the config the embedder is currently registered under (old api_key).
// new_apikey:   the key to install.
//
// Returns 404 when no embedder is registered for `model_config`, 400 when the
// registered entry is null / not remote or the embedder refuses the key, and
// true on success. Runs under text_embedders_mutex.
Option<bool> EmbedderManager::update_remote_model_apikey(const nlohmann::json &model_config, const std::string& new_apikey) {
    std::unique_lock<std::mutex> lock(text_embedders_mutex);

    const auto& model_key = RemoteEmbedder::get_model_key(model_config);

    // Single lookup instead of find() followed by repeated operator[] calls.
    const auto embedder_it = text_embedders.find(model_key);
    if(embedder_it == text_embedders.end()) {
        return Option<bool>(404, "Text embedder was not found.");
    }

    // Hold our own reference so the embedder outlives the erase below.
    const auto embedder = embedder_it->second;
    if(embedder == nullptr || !embedder->is_remote()) {
        return Option<bool>(400, "Text embedder is not valid.");
    }

    if(!embedder->update_remote_embedder_apikey(new_apikey)) {
        return Option<bool>(400, "Failed to update remote model api_key.");
    }

    // Re-key the registry entry: the model key is recomputed from the config with the new api_key.
    auto updated_model_config = model_config;
    updated_model_config["api_key"] = new_apikey;
    const auto& updated_model_key = RemoteEmbedder::get_model_key(updated_model_config);

    // When the key does not change, assign-then-erase would drop the only entry
    // for this embedder, so the registry is left untouched in that case.
    if(updated_model_key != model_key) {
        text_embedders.erase(embedder_it);
        text_embedders[updated_model_key] = embedder;
    }

    return Option<bool>(true);
}
Option<bool> EmbedderManager::validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims) {
const std::string& model_name = model_config["model_name"].get<std::string>();
Option<bool> public_model_op = EmbedderManager::get_instance().init_public_model(model_name);

View File

@ -139,6 +139,8 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
}
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
std::shared_lock<std::shared_mutex> lock(mutex);
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
@ -330,6 +332,8 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
}
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
std::shared_lock<std::shared_mutex> lock(mutex);
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Content-Type"] = "application/json";
@ -481,6 +485,8 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, siz
}
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
std::shared_lock<std::shared_mutex> lock(mutex);
nlohmann::json req_body;
req_body["instances"] = nlohmann::json::array();
nlohmann::json instance;